# Trial segmentation script

This notebook is intended for the breakdown, classification and analysis of the behavioural states shown in the video, starting from a trial segmentation.

- `trial segmentation` takes the output of simplerCode.py (video + .csv), and segments the video based on mouse entry and exit times in the video rather than the .csv
- we then will manually sort the trials into `exploitative`, `explorative` and `nest`
- extract speed and trajectories per trial
- identify behavioural syllables with keypoint moseq 2D



In [None]:
# If needed, install deps in THIS kernel
# (`%pip` is an IPython magic: this cell only runs inside a notebook kernel.)
%pip install -q numpy pandas opencv-python imageio-ffmpeg

# Expose a working ffmpeg binary (uses imageio-ffmpeg's embedded build)
import os, imageio_ffmpeg
# setdefault: a value of IMAGEIO_FFMPEG_EXE that is already set in the
# environment is kept; the bundled binary is only used when none is set.
os.environ.setdefault("IMAGEIO_FFMPEG_EXE", imageio_ffmpeg.get_ffmpeg_exe())
print("Using ffmpeg:", os.environ["IMAGEIO_FFMPEG_EXE"])


In [None]:
from pathlib import Path

# === REQUIRED: point to a good reference video (same rig/pose) ===
# The reference ROIs are drawn on the first readable frame of this video.
REFERENCE_VIDEO = r"C:\path\to\your\best_reference_session.mp4"  # <-- EDIT

# Where to save the ROIs you draw on the reference (CSV columns: name,x,y,w,h)
REF_ROIS_CSV    = r"C:\path\to\rois_reference.csv"               # <-- EDIT

# Entrance + reward ROIs to draw (order matters): entrance1, entrance2, roi1..roi4.
# Bout detection refuses to run unless both "entrance1" and "entrance2" exist.
ROI_NAMES = ["entrance1", "entrance2"] + [f"roi{i}" for i in range(1, 5)]

# Vision/segmentation params (tweak if needed)
THRESH_VALUE     = 160    # grayscale cutoff of the binary threshold applied before summing ROI pixels
THRESH_FACTOR    = 0.50   # ROI considered "occupied" if current white-sum < baseline*factor
MIN_DURATION_S   = 0.40   # drop bouts shorter than this (applied after merging)
MERGE_GAP_S      = 0.20   # merge consecutive bouts whose gap is at most this
PADDING_S        = 0.10   # ± seconds added around each bout when cutting
REENCODE         = True   # True=precise cuts (re-encode with libx264); False=keyframe "copy"
PYR_SCALE        = 0.75   # frames are resized by this factor before feature matching; 0.5–0.8 often robust; 1.0 = full-res
FALLBACK_TO_REF_ROIS = True  # if homography fails, use the reference ROIs directly

# Output naming
COLUMN_NAME = "video_segment_path"  # column added to *trial_info.csv


In [None]:
import os, shutil, subprocess
from typing import List, Tuple, Optional, Dict
import numpy as np
import pandas as pd
import cv2 as cv

# ---------- ffmpeg ----------
def ffmpeg_bin():
    """Return the ffmpeg executable to invoke.

    The ``IMAGEIO_FFMPEG_EXE`` environment variable wins when it holds a
    non-empty value; otherwise the bare command name ``"ffmpeg"`` is returned
    so that the binary is looked up on ``PATH``.
    """
    configured = os.environ.get("IMAGEIO_FFMPEG_EXE")
    if configured:
        return configured
    return "ffmpeg"

def ensure_ffmpeg():
    """Raise ``RuntimeError`` unless the configured ffmpeg can be located.

    A value containing ``":"`` or ending in ``".exe"`` is treated as a file
    path and must exist on disk; anything else is resolved with
    ``shutil.which``.
    """
    exe = ffmpeg_bin()
    looks_like_path = ":" in exe or exe.endswith(".exe")
    if looks_like_path:
        found = Path(exe).exists()
    else:
        found = shutil.which(exe) is not None
    if found:
        return
    raise RuntimeError(f"ffmpeg not found at '{exe}'. Install ffmpeg or set IMAGEIO_FFMPEG_EXE.")

# ---------- video helpers ----------
def open_video_any(path: str):
    """Open *path* with the first OpenCV backend that can decode a frame.

    Backends are tried in the order MSMF, FFMPEG, DSHOW, ANY. A backend is
    accepted only if one frame can actually be read; the capture is then
    rewound to frame 0 and returned. Returns ``None`` when no backend works.
    """
    candidate_backends = (cv.CAP_MSMF, cv.CAP_FFMPEG, cv.CAP_DSHOW, cv.CAP_ANY)
    for api in candidate_backends:
        capture = cv.VideoCapture(path, api)
        got_frame, probe = capture.read()
        if got_frame and probe is not None:
            # The probe read consumed a frame; rewind before handing it out.
            capture.set(cv.CAP_PROP_POS_FRAMES, 0)
            return capture
        # This backend failed: best-effort cleanup, then try the next one.
        try:
            capture.release()
        except Exception:
            pass
    return None

def grab_first_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the first frame of *video_path*, or ``None`` if it cannot be read.

    The capture opened for the read is released before returning.
    """
    capture = open_video_any(video_path)
    if capture is not None:
        success, image = capture.read()
        capture.release()
        if success:
            return image
    return None

# ---------- ROI utils ----------
def save_rois_csv(dest_csv: Path, rows: List[Tuple[str,int,int,int,int]]) -> None:
    """Write ROI rows to *dest_csv* as a CSV with columns name, x, y, w, h.

    Missing parent directories of *dest_csv* are created. No index column is
    written.
    """
    table = pd.DataFrame(rows, columns=["name", "x", "y", "w", "h"])
    target_dir = dest_csv.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    table.to_csv(dest_csv, index=False)

def load_rois_long(csv_path: str) -> List[Tuple[str,int,int,int,int]]:
    """Read an ROI CSV and return ``(name, x, y, w, h)`` tuples in file order.

    Column headers are matched case-insensitively. ROI names are lower-cased
    and the four geometry fields are converted with ``int``.

    Raises:
        ValueError: if any of the columns name, x, y, w, h is missing.
    """
    table = pd.read_csv(csv_path)
    table.columns = [col.lower() for col in table.columns]
    need = {"name","x","y","w","h"}
    if not need.issubset(table.columns):
        raise ValueError(f"ROI CSV must have columns {need}; got {set(table.columns)}")
    geometry = table[["name", "x", "y", "w", "h"]]
    return [
        (str(label).lower(), int(left), int(top), int(width), int(height))
        for label, left, top, width, height in geometry.itertuples(index=False, name=None)
    ]

def overlay_rois(image: np.ndarray, rows: List[Tuple[str,int,int,int,int]], color=(0,255,0)) -> np.ndarray:
    """Return a copy of *image* with every ROI drawn as a labelled rectangle.

    The input image is not modified. Each label is placed 10 px above the
    rectangle's top edge, but never higher than y=20.
    """
    canvas = image.copy()
    for label, left, top, width, height in rows:
        top_left = (left, top)
        bottom_right = (left + width, top + height)
        cv.rectangle(canvas, top_left, bottom_right, color, 2)
        text_origin = (left, max(20, top - 10))
        cv.putText(canvas, label, text_origin, cv.FONT_HERSHEY_SIMPLEX, 0.8, color, 2, cv.LINE_AA)
    return canvas

def draw_rois_on_first_frame(video_path: str, roi_names=ROI_NAMES, scale: float = 2.0
                             ) -> List[Tuple[str,int,int,int,int]]:
    """Interactively draw one rectangle per ROI name on the video's first frame.

    The frame is shown enlarged by *scale*; each selection is converted back to
    original-frame pixel coordinates before being returned. ROIs drawn earlier
    in the same call stay visible (in blue) while the next one is drawn.

    Args:
        video_path: Video whose first frame is used as the drawing canvas.
        roi_names: Names to prompt for, in order.
        scale: Display magnification applied to the frame.

    Returns:
        ``(name, x, y, w, h)`` tuples with lower-cased names, in original-frame
        pixels. Skipped ROIs are omitted.

    Raises:
        RuntimeError: if the first frame cannot be read, or if no ROI was drawn.
    """
    frame = grab_first_frame(video_path)
    if frame is None:
        raise RuntimeError(f"Could not open/read: {video_path}")

    disp = cv.resize(frame, None, fx=scale, fy=scale, interpolation=cv.INTER_LINEAR)
    win = "Draw ROIs - ENTER to confirm, ESC to skip"

    selections: List[Tuple[str,int,int,int,int]] = []   # original-frame coordinates
    overlays: List[Tuple[str,int,int,int,int]] = []     # display coordinates

    # try/finally: the windows must be torn down even if selection is
    # interrupted (exception, kernel interrupt), otherwise a dead OpenCV
    # window is left behind.
    try:
        cv.namedWindow(win, cv.WINDOW_NORMAL)
        cv.resizeWindow(win, disp.shape[1], disp.shape[0])
        try:
            cv.setWindowProperty(win, cv.WND_PROP_TOPMOST, 1)
        except Exception:
            # Bringing the window to the front is cosmetic; ignore failures.
            pass

        for name in roi_names:
            frame_show = disp.copy()
            for nm, sx, sy, sw, sh in overlays:
                cv.rectangle(frame_show, (sx, sy), (sx+sw, sy+sh), (255,0,0), 2)
                cv.putText(frame_show, nm, (sx, max(20, sy-10)), cv.FONT_HERSHEY_SIMPLEX, 0.8, (255,0,0), 2, cv.LINE_AA)
            cv.putText(frame_show, f"Draw {name} then ENTER (ESC to skip)", (20, 40),
                       cv.FONT_HERSHEY_SIMPLEX, 1.0, (0,0,255), 3, cv.LINE_AA)

            r = cv.selectROI(win, frame_show, fromCenter=False, showCrosshair=True)
            rx, ry, rw, rh = (int(v) for v in r)
            # A cancelled selection comes back as an empty rectangle; a click
            # without a drag also has no area. Both are treated as "skipped".
            if rw <= 0 or rh <= 0:
                print(f"[WARN] Skipped ROI: {name}")
                continue
            x = int(rx / scale); y = int(ry / scale)
            # Keep at least one pixel so down-scaling cannot yield an empty ROI.
            w = max(1, int(rw / scale)); h = max(1, int(rh / scale))
            selections.append((name.lower(), x, y, w, h))
            overlays.append((name, rx, ry, rw, rh))
    finally:
        cv.destroyAllWindows()

    if not selections:
        raise RuntimeError("No ROIs were drawn.")
    return selections

# ---------- Homography (ORB + BFMatcher + RANSAC) ----------
def detect_homography(ref_img: np.ndarray, cur_img: np.ndarray) -> Optional[np.ndarray]:
    """Estimate the homography that maps *ref_img* points onto *cur_img*.

    ORB features are detected in both images (3-channel inputs are converted
    from BGR to grayscale first), matched with a cross-checked Hamming
    brute-force matcher, and the 500 closest matches are fed to RANSAC with a
    reprojection threshold of 3.0.

    Returns:
        The 3x3 matrix from ``cv.findHomography``, or ``None`` when either
        image yields fewer than 20 keypoints, fewer than 20 matches are found,
        no homography is produced, or RANSAC keeps fewer than 20 inliers.
    """
    ref_gray = ref_img if ref_img.ndim != 3 else cv.cvtColor(ref_img, cv.COLOR_BGR2GRAY)
    cur_gray = cur_img if cur_img.ndim != 3 else cv.cvtColor(cur_img, cv.COLOR_BGR2GRAY)

    detector = cv.ORB_create(nfeatures=5000, scaleFactor=1.2, edgeThreshold=15, patchSize=31)
    ref_kp, ref_desc = detector.detectAndCompute(ref_gray, None)
    cur_kp, cur_desc = detector.detectAndCompute(cur_gray, None)
    if ref_desc is None or cur_desc is None:
        return None
    if len(ref_kp) < 20 or len(cur_kp) < 20:
        return None

    matcher = cv.BFMatcher(cv.NORM_HAMMING, crossCheck=True)
    pairs = matcher.match(ref_desc, cur_desc)
    if len(pairs) < 20:
        return None

    # Keep only the strongest correspondences (smallest descriptor distance).
    best = sorted(pairs, key=lambda m: m.distance)[:500]
    ref_pts = np.float32([ref_kp[m.queryIdx].pt for m in best]).reshape(-1, 1, 2)
    cur_pts = np.float32([cur_kp[m.trainIdx].pt for m in best]).reshape(-1, 1, 2)

    H, inlier_mask = cv.findHomography(ref_pts, cur_pts, cv.RANSAC, 3.0)
    if H is None:
        return None
    too_few_inliers = inlier_mask is not None and int(inlier_mask.sum()) < 20
    return None if too_few_inliers else H

def project_rect(x: int, y: int, w: int, h: int, H: np.ndarray) -> Tuple[int,int,int,int]:
    """Map a rectangle through homography *H* and return its bounding box.

    The four corners are transformed and enclosed in an axis-aligned box,
    rounded outwards (floor for the minimum, ceil for the maximum). The
    top-left corner is clamped to be non-negative and width/height are at
    least 1.

    Returns:
        ``(x, y, w, h)`` of the bounding box as ints.
    """
    corners = np.array(
        [[x, y], [x + w, y], [x + w, y + h], [x, y + h]],
        dtype=np.float32,
    ).reshape(-1, 1, 2)
    warped = cv.perspectiveTransform(corners, H).reshape(-1, 2)
    col, row = warped[:, 0], warped[:, 1]

    left = max(0, int(np.floor(col.min())))
    top = max(0, int(np.floor(row.min())))
    right = int(np.ceil(col.max()))
    bottom = int(np.ceil(row.max()))
    return left, top, max(1, right - left), max(1, bottom - top)

def auto_rois_from_reference(cur_img: np.ndarray,
                             ref_img: np.ndarray,
                             ref_rows: List[Tuple[str,int,int,int,int]],
                             pyr_scale: float = 0.75) -> Optional[List[Tuple[str,int,int,int,int]]]:
    """Transfer reference ROIs onto *cur_img* via a reference->current homography.

    When *pyr_scale* is not 1.0 both images are resized by that factor before
    feature matching, and the resulting homography is converted back to
    full-resolution coordinates as ``inv(S) @ H_small @ S`` with
    ``S = diag(pyr_scale, pyr_scale, 1)``.

    Returns:
        The projected ``(name, x, y, w, h)`` rows, or ``None`` if no
        homography could be estimated.
    """
    if pyr_scale == 1.0:
        H = detect_homography(ref_img, cur_img)
        if H is None:
            return None
    else:
        ref_small = cv.resize(ref_img, None, fx=pyr_scale, fy=pyr_scale)
        cur_small = cv.resize(cur_img, None, fx=pyr_scale, fy=pyr_scale)
        H_small = detect_homography(ref_small, cur_small)
        if H_small is None:
            return None
        scaling = np.array(
            [[pyr_scale, 0, 0], [0, pyr_scale, 0], [0, 0, 1]],
            dtype=np.float32,
        )
        # H_small acts on down-scaled coordinates; undo the scaling on both sides.
        H = np.linalg.inv(scaling) @ H_small @ scaling

    return [
        (name, *project_rect(x, y, w, h, H))
        for name, x, y, w, h in ref_rows
    ]

# ---------- Entrance-based bout detection ----------
def _grab(frame, r):
    return frame[r["ystart"]:r["ystart"]+r["ylen"], r["xstart"]:r["xstart"]+r["xlen"]]

def compute_roi_baselines(cap, rois: Dict[str,Dict[str,int]], num_frames=10, thresh_value=160):
    """Average the thresholded pixel sum of every ROI over the next frames of *cap*.

    Up to *num_frames* frames are read from the capture's current position.
    Each is converted to grayscale if it has 3 dimensions, binarised at
    *thresh_value*, and the pixel values inside each ROI are summed. The
    capture is then moved back to the position it had on entry.

    Returns:
        Dict mapping ROI name to its mean pixel sum over the frames read.

    Raises:
        RuntimeError: if not a single frame could be read.
    """
    totals = dict.fromkeys(rois, 0.0)
    frames_used = 0
    start_pos = int(cap.get(cv.CAP_PROP_POS_FRAMES))

    for _ in range(num_frames):
        ok, frame = cap.read()
        if not ok:
            break
        if frame.ndim == 3:
            gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
        else:
            gray = frame
        _, binary = cv.threshold(gray, thresh_value, 255, cv.THRESH_BINARY)
        for roi_name, rect in rois.items():
            totals[roi_name] += float(np.sum(_grab(binary, rect)))
        frames_used += 1

    # Rewind so the caller sees the capture exactly where it left it.
    cap.set(cv.CAP_PROP_POS_FRAMES, start_pos)
    if frames_used == 0:
        raise RuntimeError("Could not read frames to compute baselines.")
    return {roi_name: total / frames_used for roi_name, total in totals.items()}

def _merge_and_filter_bouts(bouts, merge_gap_s, min_duration_s):
    """Merge bouts whose gap is <= merge_gap_s, then drop those shorter than min_duration_s."""
    if not bouts:
        return []
    merged = []
    s0, e0 = bouts[0]
    for s, e in bouts[1:]:
        if s - e0 <= merge_gap_s:
            e0 = e
        else:
            merged.append((s0, e0))
            s0, e0 = s, e
    merged.append((s0, e0))
    return [(s, e) for (s, e) in merged if (e - s) >= min_duration_s]


def detect_bouts_by_entrances(
    video_path: str,
    rois_csv: str,
    thresh_value: int = 160,
    threshold_factor: float = 0.5,
    min_duration_s: float = 0.4,
    merge_gap_s: float = 0.2,
):
    """Detect (start_s, end_s) bouts from the order in which two entrance ROIs are vacated.

    Each frame is binarised at *thresh_value*. An ROI counts as occupied when
    its pixel sum falls below ``baseline * threshold_factor``, where the
    baseline is the ROI's mean pixel sum over the first frames of the video
    (see ``compute_roi_baselines``). A "left" event is a frame in which an ROI
    is no longer occupied although it was in the previous frame.

    State machine (times are ``frame_idx / fps``):
      * ENTER: entrance2 is left after entrance1 has been left, while no bout
        is open -> a bout starts.
      * EXIT: entrance1 is left while a bout is open and entrance2 has been
        left -> the bout ends and both "has left" flags are cleared.
    A bout still open when the video ends is closed at the end of the video.

    Args:
        video_path: Video to analyse.
        rois_csv: ROI CSV (columns name, x, y, w, h) that must define
            ``entrance1`` and ``entrance2``.
        thresh_value: Grayscale cutoff for the binary threshold.
        threshold_factor: Fraction of the baseline below which an ROI is occupied.
        min_duration_s: Bouts shorter than this are dropped (after merging).
        merge_gap_s: Consecutive bouts separated by at most this are merged.

    Returns:
        List of ``(start_s, end_s)`` tuples in seconds; empty if none found.

    Raises:
        ValueError: if the ROI CSV lacks required columns or entrance ROIs.
        FileNotFoundError: if the video cannot be opened.
        RuntimeError: if no frame can be read to compute baselines.
    """
    # Reuse the shared loader so column validation and name normalisation
    # match the rest of the notebook.
    rois = {name: {"xstart": x, "ystart": y, "xlen": w, "ylen": h}
            for name, x, y, w, h in load_rois_long(rois_csv)}
    if "entrance1" not in rois or "entrance2" not in rois:
        raise ValueError("ROIs must include 'entrance1' and 'entrance2'.")

    cap = open_video_any(str(video_path))
    if cap is None:
        raise FileNotFoundError(f"Cannot open video: {video_path}")

    bouts = []
    # try/finally: release the capture even if baseline computation or frame
    # processing raises, so the video file handle is never leaked.
    try:
        fps = cap.get(cv.CAP_PROP_FPS) or 0.0
        if fps <= 1e-3:
            fps = 30.0  # container reports no usable frame rate; assume 30
        total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT)) or None

        baselines = compute_roi_baselines(cap, rois, num_frames=10, thresh_value=thresh_value)

        def occupied(bw, name):
            return (np.sum(_grab(bw, rois[name])) < baselines[name] * threshold_factor)

        ent1_prev = False
        ent2_prev = False
        has_left1 = False
        has_left2 = False
        entered = False
        cur_start = None

        frame_idx = 0
        cap.set(cv.CAP_PROP_POS_FRAMES, 0)
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) if frame.ndim == 3 else frame
            _, bw = cv.threshold(gray, thresh_value, 255, cv.THRESH_BINARY)

            e1_now = occupied(bw, "entrance1")
            e2_now = occupied(bw, "entrance2")

            # Falling edges: occupied in the previous frame, free now.
            left1 = (not e1_now) and ent1_prev
            left2 = (not e2_now) and ent2_prev

            if left1:
                has_left1 = True
                if has_left2 and entered:
                    # EXIT (sequence 2->1)
                    end_s = frame_idx / fps
                    if cur_start is not None:
                        bouts.append((cur_start, end_s))
                    cur_start = None
                    entered = False
                    has_left1 = has_left2 = False

            if left2:
                has_left2 = True
                if has_left1 and not entered:
                    # ENTER (sequence 1->2)
                    # NOTE(review): the flags are not cleared here, so
                    # has_left2 stays True and the next left1 event closes the
                    # bout without a fresh entrance2 event - confirm intended.
                    cur_start = frame_idx / fps
                    entered = True

            ent1_prev, ent2_prev = e1_now, e2_now
            frame_idx += 1

        # Close a bout that is still open when the video ends.
        if entered and cur_start is not None:
            end_s = (total_frames / fps) if total_frames else (frame_idx / fps)
            bouts.append((cur_start, end_s))
    finally:
        cap.release()

    return _merge_and_filter_bouts(bouts, merge_gap_s, min_duration_s)

# ---------- Cut segments ----------
def cut_segment_ffmpeg(video, start_s, end_s, out_path, reencode=True):
    """Cut ``[start_s, end_s]`` (seconds) out of *video* into *out_path* with ffmpeg.

    Only the first video stream is mapped. The parent directory of *out_path*
    is created if needed, and an existing output file is overwritten.

    Args:
        video: Source video path.
        start_s: Start of the cut in seconds.
        end_s: End of the cut in seconds.
        out_path: Destination file.
        reencode: If True, re-encode with libx264 (veryfast, CRF 18) for a
            duration of ``end_s - start_s`` (clamped at 0). If False, copy the
            stream without re-encoding up to *end_s*.

    Raises:
        RuntimeError: if ffmpeg cannot be located.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    ensure_ffmpeg()
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # -y / -nostdin: never wait for an "Overwrite? [y/N]" answer. Without them
    # a re-run over existing segments blocks on stdin or aborts.
    head = [ffmpeg_bin(), "-hide_banner", "-loglevel", "error", "-nostdin", "-y"]
    tail = ["-movflags", "+faststart", "-reset_timestamps", "1", str(out_path)]

    if reencode:
        span = [
            "-ss", f"{start_s:.3f}", "-t", f"{max(0.0, end_s - start_s):.3f}",
        ]
        codec = [
            "-map", "0:v:0?", "-c:v", "libx264", "-preset", "veryfast", "-crf", "18",
        ]
    else:
        span = ["-ss", f"{start_s:.3f}", "-to", f"{end_s:.3f}"]
        codec = ["-map", "0:v:0?", "-c", "copy"]

    # Seek options precede -i so they apply to the input.
    cmd = head + span + ["-i", str(video)] + codec + tail
    subprocess.run(cmd, check=True)

def segment_video_by_detected_bouts(
    video_path: str,
    rois_csv: str,
    outdir: str,
    base_label: str,
    padding_s: float,
    reencode: bool,
    thresh_value: int,
    threshold_factor: float,
    min_duration_s: float,
    merge_gap_s: float,
):
    """Detect bouts in *video_path* and write one clip per bout into *outdir*.

    Each clip spans the bout widened by *padding_s* on both sides (the start is
    clamped at 0) and is named ``{base_label}_trial_{index:03d}.mp4``.

    Returns:
        One dict per clip with keys ``trial_index``, ``start_s``, ``end_s``
        (the unpadded bout times) and ``path``.
    """
    detected = detect_bouts_by_entrances(
        video_path,
        rois_csv,
        thresh_value=thresh_value,
        threshold_factor=threshold_factor,
        min_duration_s=min_duration_s,
        merge_gap_s=merge_gap_s,
    )

    clips = []
    for idx, (bout_start, bout_end) in enumerate(detected):
        cut_from = max(0.0, bout_start - padding_s)
        cut_to = max(cut_from, bout_end + padding_s)
        clip_path = Path(outdir) / f"{base_label}_trial_{idx:03d}.mp4"
        cut_segment_ffmpeg(video_path, cut_from, cut_to, clip_path, reencode=reencode)
        clips.append({
            "trial_index": idx,
            "start_s": bout_start,
            "end_s": bout_end,
            "path": str(clip_path),
        })
    return clips

def update_trials_csv_with_paths(trials_csv: str, segments, column_name="video_segment_path", inplace=False):
    """Attach segment file paths to the rows of a trials CSV, matched by order.

    Row *i* receives the path of segment *i*; only the first
    ``min(rows, segments)`` rows are filled. *column_name* is created (filled
    with ``pd.NA``) when absent. The result overwrites the input when
    *inplace* is true, otherwise it is written next to it as
    ``<stem>_with_segments_detected.csv``.

    Returns:
        ``(written_path, rows_filled, n_segment_paths, n_csv_rows)``.
    """
    source = Path(trials_csv).resolve()
    table = pd.read_csv(source)
    seg_paths = [seg["path"] for seg in segments]

    n_filled = min(len(table), len(seg_paths))
    if column_name not in table.columns:
        table[column_name] = pd.NA
    # Label-based slice: the end label is inclusive, hence n_filled - 1.
    table.loc[:n_filled - 1, column_name] = seg_paths[:n_filled]

    if inplace:
        target = source
    else:
        target = source.with_name(source.stem + "_with_segments_detected.csv")
    table.to_csv(target, index=False)
    return str(target), n_filled, len(seg_paths), len(table)

def build_outdir(video_path: Path):
    """Return ``<video dir>/segments_detected/<video stem>`` for *video_path*."""
    session_name = video_path.stem
    return video_path.parent.joinpath("segments_detected", session_name)


In [None]:
# Draw once on your REFERENCE_VIDEO; saves to REF_ROIS_CSV
# (interactive: opens an OpenCV window and prompts for each name in ROI_NAMES;
# the frame is displayed at 2x and coordinates are stored in original pixels)
ref_rows = draw_rois_on_first_frame(REFERENCE_VIDEO, roi_names=ROI_NAMES, scale=2.0)
save_rois_csv(Path(REF_ROIS_CSV), ref_rows)
print("Saved reference ROIs ->", REF_ROIS_CSV)

# Quick visual check (press any key to close window)
# Re-reads the first frame and overlays the ROIs that were just drawn.
first = grab_first_frame(REFERENCE_VIDEO)
vis = overlay_rois(first, ref_rows)
cv.imshow("Reference ROIs", vis); cv.waitKey(0); cv.destroyAllWindows()


In [None]:
# Paste or run your code that builds the {trial_csv: video} dict here.
# `csv_video` must exist before this cell runs: it maps each trial-info CSV
# path to the path of the matching session video.
# Example:
# csv_video = {"C:/.../mouse6359_session3.6_trial_info.csv": "C:/.../6359_2024-08-28_13_28_27s3.6.mp4", ...}

print("Pairs:", len(csv_video))
# Optional: sanity check the first few
for k, v in list(csv_video.items())[:3]:
    print("CSV:", k)
    print("VID:", v)


In [None]:
# Batch driver: for every (trial CSV, video) pair, make sure a per-session ROI
# CSV exists (adapting the reference ROIs via homography when it does not),
# cut the video into one clip per detected bout, and write a copy of the trial
# CSV with the clip paths attached. One summary dict is collected per session.
summaries = []

# Load reference once
ref_frame = grab_first_frame(REFERENCE_VIDEO)
if ref_frame is None:
    # Fail early with a clear message instead of an obscure OpenCV error the
    # first time the reference frame is used for ROI adaptation.
    raise RuntimeError(f"Could not open/read reference video: {REFERENCE_VIDEO}")
ref_rows  = load_rois_long(REF_ROIS_CSV)

for trials_csv, video in csv_video.items():
    video_p = Path(video).resolve()
    trials_p = Path(trials_csv).resolve()

    if not video_p.exists():
        print(f"[SKIP] Video not found: {video_p}")
        continue
    if not trials_p.exists():
        print(f"[SKIP] CSV not found: {trials_p}")
        continue

    # Prepare per-session ROI CSV path: "<video path without extension>_rois.csv"
    session_rois_csv = video_p.with_suffix("").as_posix() + "_rois.csv"

    # If missing, adapt from reference via homography.
    # An existing per-session ROI CSV is reused as-is and never regenerated.
    if not Path(session_rois_csv).exists():
        cur_frame = grab_first_frame(str(video_p))
        if cur_frame is None:
            print(f"[SKIP] Cannot open video for ROI adaptation: {video_p.name}")
            continue
        adapted = auto_rois_from_reference(cur_frame, ref_frame, ref_rows, pyr_scale=PYR_SCALE)
        if adapted is None:
            if FALLBACK_TO_REF_ROIS:
                print(f"[WARN] Homography failed for {video_p.name} -> using reference ROIs")
                rows = ref_rows
            else:
                print(f"[SKIP] Homography failed and fallback disabled: {video_p.name}")
                continue
        else:
            rows = adapted
        save_rois_csv(Path(session_rois_csv), rows)

        # Optional: save a preview image of the adapted ROIs.
        # Best-effort only: a failure here must not abort the session.
        try:
            preview = overlay_rois(cur_frame, rows)
            outdir = build_outdir(video_p)
            outdir.mkdir(parents=True, exist_ok=True)
            cv.imwrite(str(outdir / "rois_preview.png"), preview)
        except Exception as e:
            print("[WARN] Could not save ROI preview:", e)

    # Segment by detected bouts
    outdir = build_outdir(video_p)
    segs = segment_video_by_detected_bouts(
        video_path=str(video_p),
        rois_csv=session_rois_csv,
        outdir=str(outdir),
        base_label=video_p.stem,
        padding_s=PADDING_S,
        reencode=REENCODE,
        thresh_value=THRESH_VALUE,
        threshold_factor=THRESH_FACTOR,
        min_duration_s=MIN_DURATION_S,
        merge_gap_s=MERGE_GAP_S,
    )

    # Update the CSV with paths (row order -> segment order)
    updated_csv_path, filled, n_segments, n_rows = update_trials_csv_with_paths(
        trials_csv=str(trials_p),
        segments=segs,
        column_name=COLUMN_NAME,
        inplace=False,  # write *_with_segments_detected.csv
    )

    summaries.append({
        "video": str(video_p),
        "roi_csv": session_rois_csv,
        "segments_outdir": str(outdir),
        "segments_made": len(segs),
        "csv_original": str(trials_p),
        "csv_updated": updated_csv_path,
        "rows_filled": filled,
        "csv_rows": n_rows
    })

print("Done. Sessions processed:", len(summaries))


In [None]:
import pandas as pd
# Show the per-session summaries collected by the batch cell as a table
# (one row per processed session).
pd.DataFrame(summaries)
