<a href="https://colab.research.google.com/github/KravitzLab/FED3Analyses/blob/PsygeneL0-to-L1-QC/PsygeneBEAML0_to_L1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This is a notebook for quality control and renaming BEAM files from L0 to L1 for Psygene.  
<br>
<img src="https://github.com/KravitzLab/KreedLabWiki/blob/main/images/ChatGPT%20Image%20Apr%2020,%202025,%2004_05_24%20PM.png?raw=true" width="300" />

Updated: 2025-07-28



In [None]:
# @title Import libraries
# Install Required Packages (if missing) ---
import importlib.util
import subprocess
import sys

# Importable module name -> pip distribution name to install when it is absent.
required_packages = dict(ipywidgets="ipywidgets", openpyxl="openpyxl")

for pkg, name in required_packages.items():
    # Guard clause: skip anything the import system can already locate.
    if importlib.util.find_spec(pkg) is not None:
        continue
    print(f"Installing {name}...")
    pip_command = [sys.executable, "-m", "pip", "install", name]
    subprocess.check_call(pip_command)

# ---Imports ---
import os, re, zipfile, shutil, warnings, io
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.dates as mdates
from matplotlib.ticker import FuncFormatter
from datetime import datetime, timedelta
from collections import defaultdict
from scipy.optimize import curve_fit
from statsmodels.formula.api import ols
import statsmodels.api as sm
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files, output

# --- Enable custom widgets & suppress warnings ---
# Turn on Colab's custom widget manager (the ipywidgets UIs in later cells rely on widgets).
output.enable_custom_widget_manager()
warnings.filterwarnings('ignore')  # NOTE: silences ALL warnings from here on; use with caution

print(" All packages ready and environment set up.")




In [None]:
# @title Upload BEAM L0 files and Key


# --- Upload Files ---
# files.upload() hands back a dict keyed by uploaded filename. Each result is
# checked before it is indexed so an empty result (e.g. a cancelled dialog)
# produces a clear message instead of a bare IndexError.
print("Upload your ZIP file containing CSVs")
zip_upload = files.upload()
if not zip_upload:
    raise ValueError("No ZIP file was uploaded. Re-run this cell and select the ZIP of CSVs.")

print("Upload your XLSX key file")
xlsx_upload = files.upload()
if not xlsx_upload:
    raise ValueError("No XLSX key file was uploaded. Re-run this cell and select the key file.")

# Only the first uploaded file of each dialog is used.
zip_filename = next(iter(zip_upload))
xlsx_filename = next(iter(xlsx_upload))
key_df = pd.read_excel(xlsx_filename)

# --- Step 3: Extract, match by BEAM within window, and group ---

import re, io, zipfile
from datetime import datetime
from collections import defaultdict

# Largest allowed distance, in days, between a file's date and a key row's
# FED_StartDate for the two to be matched (compared as an absolute difference).
MATCH_WINDOW_DAYS = 15  # change if needed

# 1) Normalize the key
# Work on a copy so the columns added below do not touch the frame read from Excel.
key_df = key_df.copy()

def normalize_beam_value(x):
    """
    Return the BEAM id in *x* as a digits-only string, zero-padded to 3 digits.

    The first run of digits in ``str(x)`` is taken, leading zeros are dropped
    and the result is padded back to at least three characters. This is the
    same canonical form the filename parser produces with
    ``str(int(...)).zfill(3)``, so key values such as "0023", 23, 23.0 and
    "BEAM 23" all compare equal to a file named BEAM023.

    Returns None when *x* is missing (NaN/None) or contains no digits.
    """
    if pd.isna(x):
        return None
    m = re.search(r"(\d+)", str(x))
    if m is None:
        return None
    # int() strips leading zeros; zfill() restores the 3-digit minimum width.
    return str(int(m.group(1))).zfill(3)

# Device column: required, then normalised for matching against file names.
has_beam_column = "BEAM" in key_df.columns
if not has_beam_column:
    raise ValueError("Key file is missing a 'BEAM' column.")

normalized_beams = key_df["BEAM"].apply(normalize_beam_value)
key_df["BEAM_norm"] = normalized_beams

# Start-date column: required; values that cannot be parsed become NaT.
has_start_column = "FED_StartDate" in key_df.columns
if not has_start_column:
    raise ValueError("Key file is missing a 'FED_StartDate' column.")
parsed_starts = pd.to_datetime(key_df["FED_StartDate"], errors="coerce")
key_df["FED_StartDate_parsed"] = parsed_starts

# 2) Helpers for filename parsing (supports several patterns)
def parse_beam_and_date_from_name(name: str):
    """
    Extract the BEAM device id and the recording date from a file name.

    Only the final path component of *name* is examined. The patterns are
    tried in order, case-insensitively; the first one that matches AND whose
    date text is a valid date wins.

    Parameters
    ----------
    name : str
        File name or path, e.g. "folder/BEAM023_2025030800.csv".

    Returns
    -------
    tuple
        ``(beam_id, file_date)`` where ``beam_id`` is a digits-only string
        zero-padded to at least three characters and ``file_date`` is a
        ``datetime.date``. ``(None, None)`` if no pattern yields a valid date.
    """
    base = Path(name).name
    # Named groups state which capture is the device and which is the date,
    # so the date-first pattern needs no special-casing by position.
    patterns = [
        (r"BEAM(?P<beam>\d{3})_(?P<date>\d{10})\.csv", "%Y%m%d%H"),      # BEAM023_2025030800.csv
        (r"BEAM(?P<beam>\d{1,4})_(?P<date>\d{8})\.csv", "%Y%m%d"),       # BEAM23_20250308.csv
        (r"BEAM\s*(?P<beam>\d{1,4}).*?(?P<date>\d{4}-\d{2}-\d{2})", "%Y-%m-%d"),  # BEAM 23 ... 2025-03-08.csv
        (r"(?P<date>\d{8})_BEAM(?P<beam>\d{1,4})\.csv", "%Y%m%d"),       # 20250308_BEAM23.csv
    ]
    for pat, fmt in patterns:
        m = re.search(pat, base, flags=re.IGNORECASE)
        if m is None:
            continue
        try:
            file_dt = datetime.strptime(m.group("date"), fmt).date()  # date-only
        except ValueError:
            # The digits matched but are not a real date/hour; try the next pattern.
            continue
        beam_id = str(int(m.group("beam"))).zfill(3)
        return beam_id, file_dt
    return None, None

# 3) Extract/scan the ZIP, match to key within window
beam_data = defaultdict(list)   # BEAM number (int) -> every parsed DataFrame for that device
raw_beam_files = {}             # "<zip member>__<Mouse_ID>" -> DataFrame, one per (file, key row) match
unmatched_reasons = []          # (zip member, reason) diagnostics, reported after the scan

with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    for file in zip_ref.namelist():
        if not file.lower().endswith(".csv"):
            continue

        # Archives created on macOS carry metadata entries ("__MACOSX/..." and
        # "._<name>.csv") that look like data files by name but are not CSVs.
        member_parts = file.split("/")
        if "__MACOSX" in member_parts[:-1] or member_parts[-1].startswith("._"):
            continue

        beam_id, file_date = parse_beam_and_date_from_name(file)
        if beam_id is None or file_date is None:
            unmatched_reasons.append((file, "Filename not parsed"))
            continue

        # load CSV; one unreadable member is recorded and skipped rather than
        # aborting the whole import
        try:
            with zip_ref.open(file) as f:
                df = pd.read_csv(f)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
            unmatched_reasons.append((file, f"Could not read CSV ({type(e).__name__}: {e})"))
            continue
        df["BEAM"] = beam_id
        df["file_date"] = pd.to_datetime(file_date)

        # For the big table later
        df["source_file"] = file
        beam_data[int(beam_id)].append(df.copy())

        # All possible key matches for this BEAM (rows without a usable start date are ignored)
        key_matches = key_df[key_df["BEAM_norm"] == beam_id].copy()
        key_matches = key_matches.dropna(subset=["FED_StartDate_parsed"])

        if key_matches.empty:
            unmatched_reasons.append((file, f"No BEAM {beam_id} found in key"))
            continue

        # Date-only comparison, allow multiple cohort matches
        matched_any = False
        for _, row in key_matches.iterrows():
            start_date = pd.to_datetime(row["FED_StartDate_parsed"]).date()
            if abs((file_date - start_date).days) <= MATCH_WINDOW_DAYS:
                df_match = df.copy()
                # Attach cohort columns if present
                if "Mouse_ID" in row:
                    df_match["Mouse_ID"] = row["Mouse_ID"]
                if "Cohort" in row:
                    df_match["Cohort"] = row["Cohort"]

                # Store under a unique key to avoid overwriting if multiple matches
                unique_key = f"{file}__{row.get('Mouse_ID', 'unknown')}"
                raw_beam_files[unique_key] = df_match
                matched_any = True

        if not matched_any:
            # Report nearest distance for debugging
            dists = (key_matches["FED_StartDate_parsed"].dt.date.apply(
                lambda d: abs((file_date - d).days)))
            nearest = int(dists.min()) if len(dists) else None
            unmatched_reasons.append((file, f"No key date within {MATCH_WINDOW_DAYS} days (nearest={nearest})"))

# 4) Build BEAM_data and merged outputs (unchanged behavior, but robust to multiples)
if not raw_beam_files:
    # Nothing matched: print the collected diagnostics before stopping, because
    # the normal reporting step further down is never reached on this path.
    print("No files matched within the window. Diagnostics (first 25 shown):")
    for fn, reason in unmatched_reasons[:25]:
        print(f"  - {fn}: {reason}")
    raise SystemExit("No files matched within the window. See diagnostics printed above.")

# Combine all matched dfs into one long table
all_matched_dfs = list(raw_beam_files.values())
BEAM_data = pd.concat(all_matched_dfs, ignore_index=True)

# Group by (BEAM, Mouse_ID) and concatenate
beam_grouped = defaultdict(list)
for df in raw_beam_files.values():
    beam_id = df["BEAM"].iloc[0]
    mouse_id = df["Mouse_ID"].iloc[0] if "Mouse_ID" in df.columns else "unknown"
    group_key = (beam_id, mouse_id)
    beam_grouped[group_key].append(df)

# One merged DataFrame per device+mouse, keyed by the file name it will be reviewed under
merged_beam_files = {}
for (beam_id, mouse_id), dfs in beam_grouped.items():
    merged_df = pd.concat(dfs, ignore_index=True)
    # Convert datetime column if present (unparseable values become NaT) and order by time
    if "datetime" in merged_df.columns:
        merged_df["datetime"] = pd.to_datetime(merged_df["datetime"], errors="coerce")
        merged_df.sort_values("datetime", inplace=True)
    fname = f"BEAM{str(beam_id).zfill(3)}_{mouse_id}.csv"
    merged_beam_files[fname] = merged_df

# Also update BEAM_data index if 'datetime' exists
if "datetime" in BEAM_data.columns:
    BEAM_data["datetime"] = pd.to_datetime(BEAM_data["datetime"], errors="coerce")
    BEAM_data.set_index("datetime", inplace=True)

# 5) Reporting: summarise what was matched, then list up to 25 files that were not.
n_matched, n_merged = len(all_matched_dfs), len(merged_beam_files)
print(f"Combined {n_matched} matched files into {n_merged} unique device+mouse merged files.")
print(f"Final BEAM_data shape: {BEAM_data.shape}")
if unmatched_reasons:
    print("\nUnmatched/diagnostic reasons (first 25 shown):")
    print("\n".join(f"  - {fn}: {reason}" for fn, reason in unmatched_reasons[:25]))

In [None]:
# @title Plot all files
# One colour per distinct Mouse_ID.
num_animals = BEAM_data["Mouse_ID"].nunique()
palette = sns.color_palette("husl", n_colors=num_animals)

# Draw every mouse's activity trace on a single wide axis.
overview_fig, overview_ax = plt.subplots(figsize=(16, 6))
line_style = dict(alpha=0.7, linewidth=0.8)
sns.lineplot(
    data=BEAM_data,
    x="datetime",
    y="activity_percent",
    hue="Mouse_ID",
    palette=palette,
    ax=overview_ax,
    **line_style,
)
overview_ax.set_xlabel("Date")
overview_ax.set_ylabel("Activity (%)")
overview_ax.set_title("Activity Over Time")
sns.despine()
# The legend is built and then dropped, leaving the plot without one.
overview_ax.legend().remove()
plt.tight_layout()
plt.show()

In [None]:
# @title Plot individual mice


# Distinct, non-missing Mouse_IDs in first-seen order, addressable by slider position.
mouse_ids_present = BEAM_data['Mouse_ID'].dropna()
unique_subjects = list(mouse_ids_present.unique())
subject_indices = dict(enumerate(unique_subjects))

# Output area that plot_subject() redraws into.
plot_output = widgets.Output()

def plot_subject(i):
    """
    Redraw ``plot_output`` with the activity trace of the i-th subject.

    ``i`` is a key of ``subject_indices``. The plot title carries the Mouse_ID,
    the span between the first and last ``datetime`` value, and the mean of
    ``activity_percent`` multiplied by 100. If the subject has no rows, a short
    message is printed instead of a plot.
    """
    subject = subject_indices[i]
    belongs_to_subject = BEAM_data['Mouse_ID'] == subject
    # reset_index() turns the frame's index back into an ordinary column.
    df = BEAM_data.loc[belongs_to_subject].copy().reset_index()

    with plot_output:
        plot_output.clear_output(wait=True)
        if df.empty:
            print(f"No data for subject: {subject}")
            return

        # Title pieces: recording span and average activity.
        span = df['datetime'].max() - df['datetime'].min()
        duration_str = f"{span.days} days, {span.seconds // 3600} hours"
        avg_activity = df["activity_percent"].dropna().mean() * 100
        avg_str = f"{avg_activity:.1f}%"
        title_text = f"Mouse_ID: {subject}    Duration: {duration_str}    Avg Activity: {avg_str}"

        fig, ax = plt.subplots(figsize=(10, 3))
        sns.lineplot(data=df, x="datetime", y="activity_percent", alpha=0.5, ax=ax)
        ax.set_ylabel("Activity")
        ax.set_xlabel("Time")
        ax.set_title(title_text)

        # Date ticks: automatic spacing, two-line "Mon DD / HH:MM" labels, slanted.
        date_axis = ax.xaxis
        date_axis.set_major_locator(mdates.AutoDateLocator())
        date_axis.set_major_formatter(mdates.DateFormatter('%b %d\n%H:%M'))
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        sns.despine()
        plt.tight_layout()
        plt.show()

# Slider that walks through the subjects by position (0 .. number of subjects - 1).
slider_layout = widgets.Layout(width='500px')
slider = widgets.IntSlider(
    description="Subject Index:",
    min=0,
    max=len(unique_subjects) - 1,
    step=1,
    style={'handle_color': '#444'},
    layout=slider_layout,
)

# Label showing which Mouse_ID the slider currently points at; starts on the first one.
first_subject = subject_indices[0]
subject_label = widgets.HTML(
    layout=widgets.Layout(margin='0 0 10px 0'),
    value=f"<b>Subject:</b> {first_subject}",
)

def on_slider_change(change):
    """Slider observer: show the newly selected subject in the label and redraw its plot."""
    selected = change['new']
    chosen_subject = subject_indices[selected]
    subject_label.value = f"<b>Subject:</b> {chosen_subject}"
    plot_subject(selected)

# Re-run on_slider_change whenever the slider's 'value' trait changes.
slider.observe(on_slider_change, names='value')

# Stack label, slider and plot area vertically, show them, then draw the first subject.
ui = widgets.VBox([subject_label, slider, plot_output])
display(ui)
plot_subject(0)

# Per-mouse means of activity and lux; mice with a missing mean in either column are dropped.
per_mouse_means = BEAM_data.groupby("Mouse_ID")[["activity_percent", "lux"]].mean()
per_mouse_means = per_mouse_means.dropna().reset_index()

# Final table: activity scaled by 100, both averages rounded to one decimal.
activity_summary = pd.DataFrame(
    {
        "Mouse_ID": per_mouse_means["Mouse_ID"],
        "Avg_Activity (%)": (per_mouse_means["activity_percent"] * 100).round(1),
        "Avg_Lux": per_mouse_means["lux"].round(1),
    }
)




In [None]:
# @title QC and manual file inclution

# QC thresholds used by the checks below.
MIN_AVG_ACTIVITY = 5.0  # % -- files whose mean activity (x100) is below this are flagged
MAX_ZERO_DURATION = timedelta(hours=8)  # longest tolerated unbroken run of zero activity
MAX_GAP_DURATION = timedelta(hours=2)   # longest tolerated gap between consecutive timestamps

# Result containers: names of files that pass, and {file name: [reasons]} for files that do not.
passed_files = []
flagged_files = {}
print(f"Running QC on {len(merged_beam_files)} merged files...")
def has_long_zero_percent(df, threshold=MAX_ZERO_DURATION):
    """
    Tell whether ``df`` contains a run of consecutive zero-activity rows that
    spans at least ``threshold``.

    A run's span is the difference between the index values of its last and
    first row, so the caller is expected to pass a frame indexed (and ordered)
    by time. ``df`` itself is not modified.
    """
    is_zero = (df["activity_percent"] == 0).to_numpy()
    if not is_zero.any():
        return False

    # Pad with a non-zero sample on both ends so every run of zeros has a
    # rising edge (+1) at its first row and a falling edge (-1) just past its last.
    edges = np.diff(np.concatenate(([0], is_zero.astype(np.int8), [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1) - 1

    stamps = df.index
    return any(
        stamps[last] - stamps[first] >= threshold
        for first, last in zip(run_starts, run_ends)
    )

flagged_files = {}
passed_files = []

# Threshold text for the reason strings is derived from the constants, so the
# messages stay truthful if the thresholds are edited (2 h -> "2", 1.5 h -> "1.5").
gap_limit_hr = f"{MAX_GAP_DURATION.total_seconds() / 3600:g}"
zero_limit_hr = f"{MAX_ZERO_DURATION.total_seconds() / 3600:g}"

for fname, df in merged_beam_files.items():
    df = df.copy()

    # The checks need a time index; accept 'datetime' as a column or as the index name.
    if "datetime" not in df.columns:
        if df.index.name == "datetime":
            df = df.reset_index()
        else:
            raise KeyError(f"'datetime' column not found in file: {fname}")

    df["datetime"] = pd.to_datetime(df["datetime"])
    df.set_index("datetime", inplace=True)
    df.sort_index(inplace=True)

    reasons = []

    # Check 1: mean activity (scaled by 100) below the minimum.
    avg_pct = df["activity_percent"].mean() * 100
    if avg_pct < MIN_AVG_ACTIVITY:
        reasons.append(f"Low avg activity ({avg_pct:.2f}%)")

    # Check 2: largest step between consecutive timestamps exceeds the allowed gap.
    gap = df.index.to_series().diff().max()
    if gap > MAX_GAP_DURATION:
        gap_minutes = round(gap.total_seconds() / 60)
        reasons.append(f"Data gap >{gap_limit_hr}hr (max: {gap_minutes} min)")

    # Check 3: an unbroken run of zero activity at least MAX_ZERO_DURATION long.
    # The threshold is passed explicitly so the check and the message always agree.
    if has_long_zero_percent(df, threshold=MAX_ZERO_DURATION):
        reasons.append(f"Zero activity >{zero_limit_hr}hrs")

    if reasons:
        flagged_files[fname] = reasons
    else:
        passed_files.append(fname)

# üëÅÔ∏è‚Äçüó®Ô∏è Manual Review with Slider
flagged_fnames = list(flagged_files.keys())
approved_files = passed_files.copy()
override_files = []
rejected_files = []
inclusion_map = {}
decision_log = {}

if flagged_fnames:
    print(f"\n{len(flagged_fnames)} files flagged. Launching manual review...")

    plot_output = widgets.Output()
    button_output = widgets.Output()
    decision_log_output = widgets.Output()

    def update_plot(change=None):
      fname = flagged_fnames[slider.value]
      df = merged_beam_files[fname].copy()
      df = df.reset_index() if "datetime" not in df.columns else df.copy()
      df["datetime"] = pd.to_datetime(df["datetime"])

      with plot_output:
          clear_output(wait=True)
          plt.figure(figsize=(10, 4))
          plt.plot(df["datetime"], df["activity_percent"] * 100, label="Activity %")
          plt.title(fname)
          plt.ylabel("Activity %")
          plt.xlabel("Time")
          plt.grid(True)
          plt.xticks(rotation=45)
          plt.tight_layout()
          plt.show()

      with button_output:
            clear_output(wait=True)
            print(f"\nReviewing file {slider.value + 1} of {len(flagged_fnames)}: {fname}")
            print("Reason(s):", " | ".join(flagged_files[fname]))
            status = inclusion_map.get(fname, None)
            print("Current decision:",
                  "Included" if status else "Excluded" if status is False else "Undecided")

            include_btn = widgets.Button(description="Include", button_style='success')
            exclude_btn = widgets.Button(description="Exclude", button_style='danger')

            def handle_decision(decision):
                if fname in approved_files: approved_files.remove(fname)
                if fname in override_files: override_files.remove(fname)
                if fname in rejected_files: rejected_files.remove(fname)

                inclusion_map[fname] = decision
                if decision:
                    approved_files.append(fname)
                    override_files.append(fname)
                else:
                    rejected_files.append(fname)

                log_decision(fname, decision)
                if slider.value < slider.max:
                    slider.value += 1
                else:
                    print("Review complete.")

            include_btn.on_click(lambda _: handle_decision(True))
            exclude_btn.on_click(lambda _: handle_decision(False))
            display(widgets.HBox([include_btn, exclude_btn]))

    def log_decision(fname, decision):
        decision_log[fname] = f"{'Included' if decision else 'Excluded'}: {fname}"
        with decision_log_output:
            clear_output(wait=True)
            for line in decision_log.values():
                print(line)

    slider = widgets.IntSlider(min=0, max=len(flagged_fnames) - 1, step=1, description="File")
    slider.observe(update_plot, names='value')

    display(slider, plot_output, button_output, widgets.Label("Decision log:"), decision_log_output)
    update_plot()
else:
    print("No files flagged. All passed QC.")


In [None]:
#@title Renaming approved files and updating Key

# ---------- Imports ----------
import os, shutil, zipfile, re
import pandas as pd

# Optional: Colab download helper
try:
    from google.colab import files as colab_files
except Exception:
    colab_files = None

# ---------- Helper: normalize BEAM to 3-digit digits-only string ----------
def normalize_beam_value(x):
    """
    Canonicalise a BEAM identifier to a digits-only string of at least 3 characters.

    The first run of digits found in ``str(x)`` is used; leading zeros are
    removed and the value is then left-padded with zeros to width 3
    (12 -> '012', 'BEAM 12' -> '012', '0007' -> '007').

    Returns None when ``x`` is missing or holds no digits.
    """
    if pd.isna(x):
        return None
    found = re.search(r"(\d+)", str(x))
    if found is None:
        return None
    significant = found.group(1).lstrip("0")
    # An all-zero id strips down to '', which is treated as the single digit '0'.
    return (significant or "0").zfill(3)

# ---------- Book-keeping filled in by the steps below ----------
skipped_files = []
flagged_log = []

# ---------- Start from an empty output directory ----------
renamed_dir = "renamed_files"
output_dir_exists = os.path.exists(renamed_dir)
if output_dir_exists:
    shutil.rmtree(renamed_dir)
os.makedirs(renamed_dir)

# ---------- KEY NORMALIZATION ----------
# Re-derive BEAM_norm with this cell's normaliser and keep only rows with a usable BEAM.
key_df = key_df.copy()
beam_norm_values = key_df["BEAM"].apply(normalize_beam_value)
key_df["BEAM_norm"] = beam_norm_values
key_df = key_df.dropna(subset=["BEAM_norm"])
key_index = key_df.set_index("BEAM_norm")

# Column that receives the review notes; created empty when the key lacks it.
session_col = "BEAM_EX"
column_missing = session_col not in key_df.columns
if column_missing:
    key_df[session_col] = pd.NA

# ---------- Write every QC-approved merged file under its new name ----------
for fname, df in merged_beam_files.items():
    if fname not in approved_files:
        continue  # excluded or never reviewed

    # Pull BEAM, mouse and date out of the frame; any failure skips the file with a note.
    try:
        beam_id_norm = normalize_beam_value(df["BEAM"].iloc[0])
        if beam_id_norm is None:
            raise ValueError("BEAM missing/unparseable")

        mouse_id = str(df["Mouse_ID"].iloc[0])

        # file_date may differ between rows; the earliest valid one names the file.
        file_dates = pd.to_datetime(df["file_date"], errors="coerce")
        if not file_dates.notna().any():
            raise ValueError("file_date missing/unparseable")
        date_fmt = file_dates.min().strftime("%Y%m%d")

    except Exception as e:
        skipped_files.append((fname, f"Metadata extraction error: {e}"))
        continue

    beam_known = beam_id_norm in key_index.index
    if not beam_known:
        skipped_files.append((fname, f"BEAM ID not found in key (searched='{beam_id_norm}')"))
        continue

    # To include the device in the name as well, use instead:
    # new_fname = f"{mouse_id}_BEAM{beam_id_norm}_{date_fmt}.csv"
    new_fname = f"{mouse_id}_BEAM_{date_fmt}.csv"

    destination = os.path.join(renamed_dir, new_fname)
    df.to_csv(destination, index=False)

# ---------- Update flagged notes in key (normalized BEAM + Mouse_ID) ----------
# For each flagged merged file, write its QC reasons into the notes column of
# the key row(s) it belongs to.
for fname, reason in flagged_files.items():
    if fname not in merged_beam_files:
        continue
    df = merged_beam_files[fname]
    beam_id_norm = normalize_beam_value(df["BEAM"].iloc[0])
    if beam_id_norm is None:
        continue

    note = " | ".join(reason)
    row_mask = key_df["BEAM_norm"] == beam_id_norm
    # The same BEAM can appear in several key rows (different mice), so the
    # note is limited to the mouse this merged file was built for. Both sides
    # are compared as text because the key column may be numeric.
    if "Mouse_ID" in df.columns and "Mouse_ID" in key_df.columns:
        row_mask = row_mask & (key_df["Mouse_ID"].astype(str) == str(df["Mouse_ID"].iloc[0]))

    key_df.loc[row_mask, session_col] = note
    flagged_log.append((beam_id_norm, note))

# ---------- Identify key rows (BEAM + Mouse_ID) with no matching merged file ----------
no_file_log = []
key_modified = False

# (normalized BEAM, Mouse_ID as text) for every merged file.
merged_keys = {
    (normalize_beam_value(df["BEAM"].iloc[0]), str(df["Mouse_ID"].iloc[0]))
    for df in merged_beam_files.values()
}

# Mouse_ID is compared as text: merged_keys stores str(Mouse_ID), and a numeric
# key column compared against a string would never match any row.
mouse_id_text = key_df["Mouse_ID"].astype(str)

for _, row in key_df.iterrows():
    beam_id_norm = row["BEAM_norm"]
    mouse_id = str(row["Mouse_ID"])
    if (beam_id_norm, mouse_id) not in merged_keys:
        condition = (key_df["BEAM_norm"] == beam_id_norm) & (mouse_id_text == mouse_id)
        key_df.loc[condition, session_col] = "no file"
        no_file_log.append((beam_id_norm, mouse_id))
        key_modified = True

# ---------- Extract base names from uploaded files (robust) ----------
# Fall back to generic names when the upload cell's variables are not defined.
zip_base = os.path.splitext(os.path.basename(str(zip_filename))) if 'zip_filename' in globals() else ("output", ".zip")
zip_base = zip_base[0]
key_base = os.path.splitext(os.path.basename(str(xlsx_filename))) if 'xlsx_filename' in globals() else ("key", ".xlsx")
key_base = key_base[0]

# ---------- Save updated key (keeps BEAM_norm; you can drop original BEAM if you want) ----------
key_out_path = f"{key_base}_updated.xlsx"
key_df.to_excel(key_out_path, index=False)

# ---------- Zip renamed files ----------
zip_out_path = f"{zip_base.replace('_L0', '_L1')}.zip"
# If the uploaded archive's name has no '_L0' to replace, the output would get
# the same name and overwrite the uploaded L0 archive; add an '_L1' suffix instead.
if 'zip_filename' in globals() and os.path.abspath(zip_out_path) == os.path.abspath(str(zip_filename)):
    zip_out_path = f"{zip_base}_L1.zip"
with zipfile.ZipFile(zip_out_path, "w") as zipf:
    for f in os.listdir(renamed_dir):
        # arcname keeps the archive flat (no 'renamed_files/' prefix).
        zipf.write(os.path.join(renamed_dir, f), arcname=f)

# ---------- Summary report ----------
# Collect every line first, then emit the report in one go.
summary_lines = [f"\n{len(no_file_log)} Mouse_IDs in key had no merged BEAM file:"]
summary_lines.extend(f"  - BEAM {beam}, Mouse {mouse}" for beam, mouse in no_file_log)

summary_lines.append(f"\nRenamed {len(os.listdir(renamed_dir))} approved merged files.")
summary_lines.append(f"Skipped {len(skipped_files)} files:")
summary_lines.extend(f"  - {fn}: {reason}" for fn, reason in skipped_files)
summary_lines.append(f"Flagged {len(flagged_log)} files updated in key.")

print("\n".join(summary_lines))

# ---------- Apply Excel styling (red/orange fills, strict rules) ----------
# Re-opens the key workbook written above and colours the notes column:
#   red    - key row has no merged file, or its file was flagged and is not approved
#   orange - file was flagged by QC but is in approved_files
# Any failure in here is reported and skipped; the unstyled workbook is kept.
try:
    from openpyxl import load_workbook
    from openpyxl.styles import PatternFill

    # Build helper to normalize BEAM
    def _norm_beam(x):
        """Return the first digit run of x without leading zeros, zero-padded to 3; None if missing."""
        if pd.isna(x): return None
        m = re.search(r"(\d+)", str(x))
        if not m: return None
        d = m.group(1)
        d = d.lstrip("0") or "0"
        return d.zfill(3)

    # Sets for logic: (normalized BEAM, Mouse_ID as text) pairs
    merged_keys = {
        (_norm_beam(df["BEAM"].iloc[0]), str(df["Mouse_ID"].iloc[0]))
        for df in merged_beam_files.values()
    }

    # Pairs whose merged file is approved (passed QC or manually included).
    approved_pairs = set()
    if isinstance(approved_files, (set, list, tuple)):
        for fname in approved_files:
            if fname in merged_beam_files:
                dfm = merged_beam_files[fname]
                approved_pairs.add((_norm_beam(dfm["BEAM"].iloc[0]), str(dfm["Mouse_ID"].iloc[0])))

    # Pairs whose merged file was flagged by QC.
    flagged_pairs = set()
    if isinstance(flagged_files, dict):
        for fname in flagged_files.keys():
            if fname in merged_beam_files:
                dfm = merged_beam_files[fname]
                flagged_pairs.add((_norm_beam(dfm["BEAM"].iloc[0]), str(dfm["Mouse_ID"].iloc[0])))

    wb = load_workbook(key_out_path)
    ws = wb.active  # first sheet

    # Header map: header text in row 1 -> 1-based column number
    header_to_col = {}
    for col in range(1, ws.max_column + 1):
        val = ws.cell(row=1, column=col).value
        if val is not None:
            header_to_col[str(val)] = col

    # Ensure target column exists
    def ensure_column(col_name: str) -> int:
        """Return the column number of col_name, appending a new header cell if it is absent."""
        if col_name in header_to_col and header_to_col[col_name]:
            return header_to_col[col_name]
        new_idx = ws.max_column + 1
        ws.cell(row=1, column=new_idx).value = col_name
        header_to_col[col_name] = new_idx
        return new_idx

    red_fill    = PatternFill(fill_type="solid", start_color="FFFFC7CE", end_color="FFFFC7CE")
    orange_fill = PatternFill(fill_type="solid", start_color="FFFFE5B2", end_color="FFFFE5B2")

    # When True, cells that already carry a fill are left untouched.
    respect_existing_fill = True
    def has_custom_fill(cell):
        """True when the cell has a pattern fill whose start colour is set and not '00000000'."""
        pt = getattr(cell.fill, "patternType", None)
        if pt and pt != "none":
            rgb = getattr(cell.fill.start_color, "rgb", None)
            return rgb not in (None, "00000000")
        return False

    # We'll style BEAM_EX only when a tag applies; otherwise no fill.
    col_idx = ensure_column(session_col)
    # Positional index so DataFrame row i lines up with worksheet row i + 2.
    key_iter = key_df.reset_index(drop=True)

    for i, row in key_iter.iterrows():
        excel_row = i + 2  # header row is 1
        pair = (row.get("BEAM_norm", None), str(row.get("Mouse_ID", "")))

        # Decide tag per your rules
        if pair not in merged_keys:
            tag = "no_file"              # RED
        elif pair in flagged_pairs and pair in approved_pairs:
            tag = "include"              # ORANGE (failed QC but manually included)
        elif pair in flagged_pairs and pair not in approved_pairs:
            # NOTE(review): this also covers flagged files that were never reviewed -- confirm intended
            tag = "reject"               # RED (manually rejected)
        else:
            tag = None                   # no color

        # Keep the existing text value in BEAM_EX
        val = row.get(session_col, pd.NA)
        cell = ws.cell(row=excel_row, column=col_idx)
        cell.value = "" if pd.isna(val) else str(val)

        # Apply fill only if a tag exists; otherwise, leave as-is (no color)
        if tag and not (respect_existing_fill and has_custom_fill(cell)):
            if tag in ("reject", "no_file"):
                cell.fill = red_fill
            elif tag == "include":
                cell.fill = orange_fill
        # If tag is None, do not touch cell.fill

    wb.save(key_out_path)
    print(f"\nStyled workbook saved to: {os.path.abspath(key_out_path)}")
except Exception as e:
    print(f"\n(Styling skipped) Could not style Excel: {e}")

# ---------- Downloads (Colab) ----------
# Outside Colab only the saved paths are reported; inside Colab both files are
# offered for download, and a failure falls back to printing where they were saved.
if colab_files is None:
    print(f"(Local) Files saved as:\n - {key_out_path}\n - {zip_out_path}")
else:
    try:
        for artifact in (key_out_path, zip_out_path):
            colab_files.download(artifact)
    except Exception as e:
        print(f"(Download hint) Could not auto-download: {e}\nFiles saved as:\n - {key_out_path}\n - {zip_out_path}")