In [None]:
# Cell 1 - Setup & imports
import os, json, glob, math, random, joblib
from pathlib import Path
import numpy as np
import pandas as pd

# Resolve the project root (the parent of the directory holding this notebook/script).
# `__file__` is not defined when the cell runs inside a Jupyter kernel, so fall
# back to the current working directory in that case instead of raising NameError.
try:
    _HERE = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _HERE = os.getcwd()
ROOT = os.path.abspath(os.path.join(_HERE, ".."))

ISIT_ROOT = os.path.join(ROOT, "data", "raw", "ISIT_Dataset")
BALABIT_ROOT = os.path.join(ROOT, "data", "raw", "Balabit_dataset_refined")
OUT_DIR = os.path.join(ROOT, "data", "processed")


os.makedirs(OUT_DIR, exist_ok=True)

# CPU / TF thread tuning (safe for notebook)
# Export the device/thread environment variables BEFORE TensorFlow is imported so
# the runtime can see them when it starts up.
# NOTE(review): numpy is imported earlier in this cell, so OMP/MKL settings may not
# affect an already-loaded BLAS/OpenMP runtime - confirm if that matters.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['OMP_NUM_THREADS'] = '6'
os.environ['MKL_NUM_THREADS'] = '6'
import tensorflow as tf
seed = 42
# Seed every RNG used by the notebook for reproducible runs.
np.random.seed(seed); random.seed(seed); tf.random.set_seed(seed)
tf.config.threading.set_intra_op_parallelism_threads(6)
tf.config.threading.set_inter_op_parallelism_threads(2)

print("Paths set. OUT_DIR:", OUT_DIR)

In [None]:
# Replacement loader cell for ISIT + Balabit (handles ISIT.trace format and Balabit CSVs w/o timestamps)
import os, glob, json, pandas as pd, math
from typing import List, Tuple, Any
from pathlib import Path

# where your roots were set previously
# ISIT_ROOT, BALABIT_ROOT must already be defined in Cell 1
SEARCH_EXTS = (".json", ".jsonl", ".ndjson", ".csv", ".txt", ".log")

def find_files_recursive(root, exts=SEARCH_EXTS):
    """Return the sorted, de-duplicated paths under ``root`` ending in one of ``exts``.

    Every directory level below ``root`` is searched with a recursive glob.
    """
    matches = []
    for suffix in exts:
        pattern = os.path.join(root, "**", "*" + suffix)
        matches.extend(glob.glob(pattern, recursive=True))
    # A set removes paths matched by more than one pattern; sorting fixes the order.
    return sorted(set(matches))

def _isit_float_or_none(value):
    """Return ``float(value)``, or None when the value cannot be converted."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None


def parse_isit_json_obj(obj):
    """
    Parse one ISIT-style session object into a flat list of pointer events.

    Expected structure (other top-level keys are ignored):
    {
      "user_id": "...",
      "device": "desktop",
      "trace": [
        {"event_name":"load","timestamp":"1695...","position":{"x":0,"y":0}},
        {"event_name":"mousemove","timestamp":"1695...","position":{"x":123,"y":456}},
        ...
      ]
    }

    The event list is read from the first of the keys "trace", "events",
    "mouse", "traj" that is present in ``obj``. For every event dict:
      * the timestamp is the first of "timestamp", "time", "ts", "t" whose
        value converts to float (a null or unparseable value no longer hides
        a usable value stored under a later key);
      * coordinates come from the nested "position"/"pos" dict, else from the
        event itself ("x"/"X", "y"/"Y"); when either is missing, the
        "clientX"/"clientY"/"pageX"/"pageY" keys are used, later keys winning.

    Returns a list of {"x": float, "y": float, "t": float | None} dicts, or
    None when ``obj`` is not a dict, has no event list, or yields no usable
    event. Events without numeric coordinates are skipped; a missing
    timestamp is reported as None so the caller can synthesize one.
    """
    if not isinstance(obj, dict):
        return None
    trace = None
    for key in ("trace", "events", "mouse", "traj"):
        if key in obj:
            trace = obj[key]
            break
    if not isinstance(trace, list):
        return None

    out = []
    for e in trace:
        if not isinstance(e, dict):
            continue

        # Timestamp: keep looking through the known keys until one converts.
        t = None
        for tk in ("timestamp", "time", "ts", "t"):
            if tk in e:
                t = _isit_float_or_none(e[tk])
                if t is not None:
                    break

        # Coordinates: nested position dict first, then the event itself.
        pos = e.get("position")
        if not isinstance(pos, dict):
            pos = e.get("pos")
        if not isinstance(pos, dict):
            pos = e
        x = pos.get("x", pos.get("X"))
        y = pos.get("y", pos.get("Y"))
        if x is None or y is None:
            # Browser-style keys; iteration order means page* overrides client*.
            for cand in ("clientX", "clientY", "pageX", "pageY"):
                if cand in e:
                    if cand.endswith("X"):
                        x = e[cand]
                    else:
                        y = e[cand]

        x_f = _isit_float_or_none(x)
        y_f = _isit_float_or_none(y)
        if x_f is None or y_f is None:
            continue
        out.append({"x": x_f, "y": y_f, "t": t})
    return out or None

def parse_balabit_csv(path, delta_ms=16.0):
    """
    Read a Balabit-style CSV of pointer samples into a list of events.

    The first row is treated as the header. Columns are chosen as follows:
      * x / y: a column named x/posx/clientx/pagex (resp. y/...), compared
        case-insensitively; otherwise the first two numeric columns that are
        not the time column; otherwise the two columns of a two-column file.
      * time: the first column whose lower-cased name contains "time" or
        equals "ts".

    Rows whose timestamp is missing (or when there is no time column) get a
    synthetic timestamp ``row_index * delta_ms``; 16 ms is roughly 60 Hz.
    Rows whose coordinates are missing or non-numeric are skipped instead of
    raising.

    Args:
        path: CSV file to read.
        delta_ms: spacing in milliseconds used for synthetic timestamps.

    Returns:
        A list of {"x": float, "y": float, "t": float} dicts (possibly
        empty), or None when the file cannot be read or no coordinate
        columns can be identified.
    """
    try:
        df = pd.read_csv(path, header=0)
    except Exception:
        # Best-effort loader: unreadable or empty files are simply skipped.
        return None

    lowered = {c: str(c).lower() for c in df.columns}

    # Named coordinate columns (the last matching column wins).
    xcol = None
    ycol = None
    for c, cl in lowered.items():
        if cl in ("x", "posx", "clientx", "pagex"):
            xcol = c
        if cl in ("y", "posy", "clienty", "pagey"):
            ycol = c

    # Detect the time column first so the numeric fallback below cannot
    # mistake it for a coordinate.
    time_col = next(
        (c for c, cl in lowered.items() if "time" in cl or cl == "ts"), None
    )

    if xcol is None or ycol is None:
        numeric_cols = [
            c for c in df.columns
            if c != time_col and pd.api.types.is_numeric_dtype(df[c])
        ]
        if len(numeric_cols) >= 2:
            xcol, ycol = numeric_cols[0], numeric_cols[1]
        elif len(df.columns) == 2:
            xcol, ycol = df.columns[0], df.columns[1]
        else:
            return None

    # errors="coerce" turns malformed cells into NaN instead of raising.
    xs = pd.to_numeric(df[xcol], errors="coerce").astype(float).tolist()
    ys = pd.to_numeric(df[ycol], errors="coerce").astype(float).tolist()
    ts = None
    if time_col is not None:
        ts = pd.to_numeric(df[time_col], errors="coerce").astype(float).tolist()

    events = []
    for i, (x, y) in enumerate(zip(xs, ys)):
        if pd.isna(x) or pd.isna(y):
            continue
        if ts is None or pd.isna(ts[i]):
            t = i * delta_ms
        else:
            t = ts[i]
        events.append({"x": x, "y": y, "t": float(t)})
    return events

def load_sessions_from_root(root):
    """
    Load, normalise and label every mouse session found under ``root``.

    Files are discovered with ``find_files_recursive`` and handled by extension:
      * .csv: parsed with ``parse_balabit_csv``;
      * .json/.jsonl/.ndjson: parsed as one JSON document (a session object,
        or a list of session objects), falling back to JSON-lines;
      * anything else: tried as a JSON document, then as JSON-lines. A single
        JSON object is now treated as one session (it used to be iterated
        key by key and therefore never parsed).

    A session's events come from ``parse_isit_json_obj``; for list items and
    JSON-lines records a raw list stored under "events" (plus "mouse"/"trace"
    for items of a .json list) is used when that parser finds nothing.

    Events are normalised to {"x": float, "y": float, "t": float | None};
    events without numeric coordinates are dropped and sessions with fewer
    than 3 usable events are discarded. Labels are inferred from keywords in
    the lower-cased file path ("bot" when a bot keyword matches, otherwise
    "human").
    NOTE(review): the keyword match runs on the whole path, so a bot keyword
    in a parent directory name labels every file below it as "bot".

    Unreadable or malformed files are skipped (best-effort loader). Progress
    and a few sample sessions are printed.

    Returns:
        list of (events, label, path) tuples.
    """
    json_exts = (".json", ".jsonl", ".ndjson")
    bad_value = (TypeError, ValueError, OverflowError)

    def parse_json_lines(txt):
        """Decode every non-empty line that is valid JSON; skip the others."""
        items = []
        for line in txt.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                items.append(json.loads(line))
            except (ValueError, RecursionError):
                continue
        return items

    def raw_event_list(item, keys):
        """Return the first non-empty list stored in ``item`` under ``keys``."""
        if not isinstance(item, dict):
            return None
        for k in keys:
            value = item.get(k)
            # Only lists are usable downstream (len(), indexing, iteration).
            if isinstance(value, list) and value:
                return value
        return None

    def normalize_event(e):
        """Convert one raw event to {"x", "y", "t"}; None when unusable."""
        if isinstance(e, dict):
            pos = e.get("position")
            if not isinstance(pos, dict):
                pos = {}
            x = e.get("x") if "x" in e else pos.get("x")
            y = e.get("y") if "y" in e else pos.get("y")
            t = e.get("t") if "t" in e else e.get("timestamp", e.get("time", None))
            try:
                x_f = float(x)
                y_f = float(y)
            except bad_value:
                # skip if coords missing
                return None
            try:
                t_f = None if t is None else float(t)
            except bad_value:
                t_f = None
            return {"x": x_f, "y": y_f, "t": t_f}
        if isinstance(e, (list, tuple)) and len(e) >= 2:
            try:
                t_f = float(e[2]) if len(e) >= 3 else None
                return {"x": float(e[0]), "y": float(e[1]), "t": t_f}
            except bad_value:
                return None
        return None

    def infer_label_from_path(p):
        """Label a session "bot" when its path contains a bot keyword."""
        pl = p.lower()
        if any(k in pl for k in ("gremlin","gremlins","za_proxy","random_mouse","sleep_bot","fake","fake_data","bot","proxy")):
            return "bot"
        # Known human folders and unknown paths both default to "human".
        return "human"

    files = find_files_recursive(root)
    print(f"Scanning {len(files)} files under {root}")
    parsed = []
    for f in files:
        ext = Path(f).suffix.lower()

        if ext == ".csv":
            try:
                evs = parse_balabit_csv(f)
            except Exception as exc:
                # Keep scanning when one CSV is malformed.
                print(f"Skipping unparseable CSV {f}: {exc}")
                continue
            if evs:
                parsed.append((evs, None, f))
            continue

        try:
            with open(f, "r", encoding="utf-8", errors="ignore") as fh:
                txt = fh.read()
        except OSError:
            continue

        try:
            obj = json.loads(txt)
            whole_document = True
        except (ValueError, RecursionError):
            obj = parse_json_lines(txt)
            whole_document = False

        if ext in json_exts and whole_document:
            if isinstance(obj, dict):
                candidates = [(obj, ())]
            elif isinstance(obj, list):
                # each item may be a session or hold a raw event list
                candidates = [
                    (item, ("events", "mouse", "trace"))
                    for item in obj if isinstance(item, dict)
                ]
            else:
                candidates = []
        else:
            if isinstance(obj, dict):
                # A lone JSON object is a single session, not a list of keys.
                obj = [obj]
            elif not isinstance(obj, list):
                obj = []
            candidates = [(item, ("events",)) for item in obj]

        for item, fallback_keys in candidates:
            evs = parse_isit_json_obj(item) or raw_event_list(item, fallback_keys)
            if evs:
                parsed.append((evs, None, f))

    print(f"Raw parsed session-like items: {len(parsed)}")
    # show some examples
    for i, (evs, lbl, p) in enumerate(parsed[:8]):
        print(f"[{i}] path={p} events={len(evs)} sample_event={evs[0] if evs else None}")

    # normalise events, drop tiny sessions, and attach labels
    out_sessions = []
    for evs, lbl, p in parsed:
        normalized = [ev for ev in map(normalize_event, evs) if ev is not None]
        if len(normalized) < 3:
            continue
        label = lbl if lbl is not None else infer_label_from_path(p)
        out_sessions.append((normalized, label, p))
    print(f"Normalized usable sessions: {len(out_sessions)}")
    return out_sessions

# Run for both roots
# Each call prints its own progress and returns a list of (events, label, path) tuples.
print("Parsing ISIT root...")
isit_sessions = load_sessions_from_root(ISIT_ROOT)
print("Parsing Balabit root...")
balabit_sessions = load_sessions_from_root(BALABIT_ROOT)

print(f"ISIT usable sessions: {len(isit_sessions)}")
print(f"Balabit usable sessions: {len(balabit_sessions)}")

# Combine
# ISIT sessions come first, followed by the Balabit sessions.
all_sessions = isit_sessions + balabit_sessions
print("TOTAL usable sessions:", len(all_sessions))

# At this point some events may have 't': None. We will synthesize missing timestamps (per-session)
def synthesize_timestamps_if_needed(events, default_delta_ms=16.0):
    """
    Fill in missing timestamps so every event has a float ``t``.

    A timestamp counts as missing when it is absent, None, NaN, or cannot be
    converted to float. Known timestamps are always kept; only the gaps are
    filled:
      * no timestamp known at all: ``t = index * default_delta_ms``;
      * gap before the first known timestamp: counted backwards from it in
        steps of ``default_delta_ms``;
      * gap after the last known timestamp: counted forwards from it in
        steps of ``default_delta_ms``;
      * gap between two known timestamps: linear interpolation between them.

    (Previously a single missing value caused every timestamp after the
    first known one to be replaced by a synthetic one.)

    Args:
        events: list of event dicts; modified in place.
        default_delta_ms: spacing used where no neighbouring pair of known
            timestamps brackets the gap.

    Returns:
        The same ``events`` list, with ``e["t"]`` set to a float everywhere.
    """
    times = []
    for e in events:
        try:
            value = float(e.get("t"))
        except (TypeError, ValueError, OverflowError):
            value = None
        if value is not None and value != value:  # NaN is treated as missing
            value = None
        times.append(value)

    known = [i for i, value in enumerate(times) if value is not None]
    if not known:
        times = [i * default_delta_ms for i in range(len(events))]
    elif len(known) < len(times):
        first, last = known[0], known[-1]
        for i in range(first):
            times[i] = times[first] - (first - i) * default_delta_ms
        for i in range(last + 1, len(times)):
            times[i] = times[last] + (i - last) * default_delta_ms
        for lo, hi in zip(known, known[1:]):
            if hi - lo > 1:
                step = (times[hi] - times[lo]) / (hi - lo)
                for k in range(lo + 1, hi):
                    times[k] = times[lo] + step * (k - lo)

    for e, value in zip(events, times):
        e["t"] = float(value)
    return events

# Apply synthesis for all_sessions
# synthesize_timestamps_if_needed returns the list it was given, so each session's
# event list is updated in place as well as re-collected here.
all_sessions_syn = []
for evs, lbl, p in all_sessions:
    evs2 = synthesize_timestamps_if_needed(evs, default_delta_ms=16.0)
    all_sessions_syn.append((evs2, lbl, p))

# Replace global all_sessions variable used by later cells
all_sessions = all_sessions_syn

print("Final session count available for downstream processing:", len(all_sessions))
# print a small sample
for i, (evs, lbl, p) in enumerate(all_sessions[:5]):
    print(f"[SAMPLE {i}] label={lbl} path={p} n_events={len(evs)} first_event={evs[0]}")

In [7]:
# Cell A - Extract per-window feature vectors (20 features) and show names
import numpy as np, math, os, joblib
from tqdm import tqdm

# Number of consecutive events fed to the feature extractor per window.
WINDOW = 10
# Step between window starts: half a window (50% overlap), never less than 1.
STRIDE = max(1, WINDOW // 2)

# feature extractor (same as before)
def extract_features_array(events):
    """
    Turn a window of {"x", "y", "t"} events into a 20-element feature vector.

    Returns a float array ordered as: speed mean/std/max, acceleration
    mean/std/max (acceleration is the first difference of speed),
    mean |dx|, std dx, mean |dy|, std dy, mean/std of turning angles,
    pause fraction, bounding-box aspect ratio, path length, speed
    percentiles 25/50/75, median dt, and the number of events.
    Windows with fewer than 3 events produce an all-zero vector.
    """
    samples = [(float(ev["x"]), float(ev["y"]), float(ev["t"])) for ev in events]
    if len(samples) < 3:
        return np.zeros(20, dtype=float)

    xs = np.array([s[0] for s in samples])
    ys = np.array([s[1] for s in samples])
    ts = np.array([s[2] for s in samples])

    # Out-of-order timestamps: rebuild a strictly increasing clock from the
    # per-step deltas, forcing every step to be at least 1.
    if np.any(np.diff(ts) < 0):
        steps = np.diff(ts, prepend=ts[0])
        ts = np.cumsum(np.clip(steps, 1.0, None))

    dx = np.diff(xs)
    dy = np.diff(ys)
    dt = np.diff(ts)
    dt = np.where(dt == 0, 1.0, dt)  # avoid division by zero
    vx = dx / dt
    vy = dy / dt
    speed = np.sqrt(vx**2 + vy**2)
    acc = np.diff(speed) if speed.size > 1 else np.array([0.0])

    # Turning angle between consecutive displacement vectors, wrapped to (-pi, pi].
    headings = [math.atan2(step_y, step_x) for step_x, step_y in zip(dx, dy)]
    turns = []
    for previous, current in zip(headings, headings[1:]):
        delta = current - previous
        while delta <= -math.pi:
            delta += 2 * math.pi
        while delta > math.pi:
            delta -= 2 * math.pi
        turns.append(delta)
    angles = np.array(turns) if turns else np.array([0.0])

    # A "pause" is a step whose dt exceeds 1.5x the 75th percentile of dt.
    pause_thresh = np.percentile(dt, 75) * 1.5
    pause_frac = float((dt > pause_thresh).sum()) / max(1, len(dt))

    width = xs.max() - xs.min() if xs.size else 0.0
    height = ys.max() - ys.min() if ys.size else 0.0
    bbox_aspect = float(width / height) if height != 0 else 0.0
    path_len = float(np.sum(np.sqrt(dx * dx + dy * dy)))

    speed_quartiles = [float(np.percentile(speed, q)) for q in (25, 50, 75)]
    feats = [
        float(np.mean(speed)), float(np.std(speed)), float(np.max(speed)),
        float(np.mean(acc)), float(np.std(acc)), float(np.max(acc)),
        float(np.mean(np.abs(dx))), float(np.std(dx)),
        float(np.mean(np.abs(dy))), float(np.std(dy)),
        float(np.mean(angles)), float(np.std(angles)),
        float(pause_frac),
        float(bbox_aspect),
        float(path_len),
        *speed_quartiles,
        float(np.median(dt)),
        float(len(xs)),
    ]
    return np.array(feats, dtype=float)

# Names of the 20 features, in the same order as the vector returned by
# extract_features_array.
feature_names = [
    "mean_speed", "std_speed", "max_speed",
    "mean_acc", "std_acc", "max_acc",
    "mean_abs_dx", "std_dx",
    "mean_abs_dy", "std_dy",
    "mean_angles", "std_angles",
    "pause_frac",
    "bbox_aspect",
    "path_len",
    "speed_p25", "speed_p50", "speed_p75",
    "median_dt",
    "n_events"
]

# Build per-window arrays and keep session->windows mapping for later
X_windows = []
y_windows = []
session_windows = []  # list of (list_of_window_features, label, path)

for evs, label, path in tqdm(all_sessions, desc="Extract sessions"):
    is_bot = 1 if label == "bot" else 0
    # At least one window is always produced: sessions shorter than WINDOW
    # yield a single window holding all of their events.
    stop = max(1, len(evs) - WINDOW + 1)
    feats_list = [
        extract_features_array(evs[start:start + WINDOW])
        for start in range(0, stop, STRIDE)
    ]
    X_windows.extend(feats_list)
    y_windows.extend([is_bot] * len(feats_list))
    session_windows.append((feats_list, is_bot, path))

X_windows = np.vstack(X_windows)
y_windows = np.array(y_windows)
print("Per-window shape:", X_windows.shape, "labels:", np.bincount(y_windows))
print("All 20 feature names:\n", feature_names)

Extract sessions: 100%|██████████| 3143/3143 [23:16<00:00,  2.25it/s]  


Per-window shape: (1617919, 20) labels: [884789 733130]
All 20 feature names:
 ['mean_speed', 'std_speed', 'max_speed', 'mean_acc', 'std_acc', 'max_acc', 'mean_abs_dx', 'std_dx', 'mean_abs_dy', 'std_dy', 'mean_angles', 'std_angles', 'pause_frac', 'bbox_aspect', 'path_len', 'speed_p25', 'speed_p50', 'speed_p75', 'median_dt', 'n_events']


In [8]:
# Cell B - Feature selection
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Feature selection over the per-window feature matrix:
#   1) drop (near-)constant features, 2) find highly correlated pairs,
#   3) rank features with a quick random forest, then keep the more important
#   member of every correlated pair and persist the resulting list.
dfX = pd.DataFrame(X_windows, columns=feature_names)
dfy = pd.Series(y_windows, name="label")

# 1) low variance removal
low_var_thresh = 1e-6
variances = dfX.var()  # computed once and reused for the mask
low_var = variances[variances <= low_var_thresh].index.tolist()
print("Low variance features removed:", low_var)

# 2) correlation-based candidates
corr = dfX.corr().abs()
# Keep only the strict upper triangle so each pair is considered once.
upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
high_corr_pairs = [(col, idx) for col in upper.columns for idx in upper.index if upper.loc[idx,col] > 0.95]
print("High-correlation feature pairs (>0.95): (showing up to 10)", high_corr_pairs[:10])

# 3) quick RF importance
Xtr_tmp, Xte_tmp, ytr_tmp, yte_tmp = train_test_split(dfX, dfy, test_size=0.2, stratify=dfy, random_state=42)
sc = StandardScaler().fit(Xtr_tmp)
Xtr_s = sc.transform(Xtr_tmp)
rf_tmp = RandomForestClassifier(n_estimators=200, class_weight='balanced', n_jobs=6, random_state=42)
rf_tmp.fit(Xtr_s, ytr_tmp)
importances = pd.Series(rf_tmp.feature_importances_, index=dfX.columns).sort_values(ascending=False)
print("Top feature importances:\n", importances.head(12))

# remove near-zero importance features
importance_thresh = 0.003  # be conservative
low_imp = importances[importances <= importance_thresh].index.tolist()
print("Low importance features removed:", low_imp)

# Build initial keep set
keep = set(feature_names) - set(low_var) - set(low_imp)

# For each high-correlation pair, drop the lower-importance one
# NOTE(review): a pair is processed even when one member was already dropped,
# so a feature can be removed because of a partner that is no longer kept.
for i in range(len(feature_names)):
    for j in range(i+1, len(feature_names)):
        f1 = feature_names[i]; f2 = feature_names[j]
        # i < j, so (f1, f2) addresses the populated upper triangle.
        if upper.loc[f1, f2] > 0.95:
            # keep the one with higher importance
            if importances[f1] >= importances[f2]:
                keep.discard(f2)
            else:
                keep.discard(f1)

selected_features = [f for f in feature_names if f in keep]
print("FINAL selected features (count={}):".format(len(selected_features)), selected_features)

# Save selected features
import json
# Use a context manager so the file is flushed and closed deterministically.
with open(os.path.join(OUT_DIR, "mouse_selected_features.json"), "w", encoding="utf-8") as fh:
    json.dump({"selected_features": selected_features}, fh, indent=2)
print("Saved selected feature list to mouse_selected_features.json")

Low variance features removed: ['n_events']
High-correlation feature pairs (>0.95): (showing up to 10) [('std_speed', 'mean_speed'), ('max_speed', 'mean_speed'), ('max_speed', 'std_speed'), ('std_acc', 'mean_speed'), ('std_acc', 'std_speed'), ('std_acc', 'max_speed'), ('max_acc', 'std_acc'), ('std_dx', 'mean_abs_dx'), ('std_dy', 'mean_abs_dy'), ('path_len', 'mean_abs_dy')]
Top feature importances:
 median_dt      0.226451
speed_p25      0.163722
mean_speed     0.138580
speed_p50      0.086216
max_speed      0.083242
std_speed      0.078602
speed_p75      0.042191
std_acc        0.032785
mean_abs_dy    0.027059
path_len       0.025292
max_acc        0.024747
std_dy         0.024086
dtype: float64
Low importance features removed: ['pause_frac', 'mean_angles', 'bbox_aspect', 'n_events']
FINAL selected features (count=9): ['mean_speed', 'mean_acc', 'std_dx', 'mean_abs_dy', 'std_angles', 'speed_p25', 'speed_p50', 'speed_p75', 'median_dt']
Saved selected feature list to mouse_selected_features.json

In [None]:
# Rebuild session_windows from raw dataset files (drop-in)
import os, json, csv, sys, traceback
from pathlib import Path
from collections import defaultdict
import numpy as np

# ---------- CONFIG ----------
# Output directory: the OUT_DIR environment variable if set, otherwise
# <cwd>/data/processed.
# NOTE(review): this rebinds the OUT_DIR defined in the setup cell, which is
# derived from ROOT rather than the working directory - confirm both agree.
OUT_DIR = os.environ.get("OUT_DIR", os.path.join(os.getcwd(), "data", "processed"))
os.makedirs(OUT_DIR, exist_ok=True)

# Where to search for datasets (add more paths here if your files live elsewhere)
candidates = []
# allow explicit env var override
if os.environ.get("DATASET_ROOT"):
    candidates.append(os.environ["DATASET_ROOT"])
# common local locations
# NOTE(review): ~/Downloads is walked recursively later on, so unrelated
# .csv/.json files stored there are picked up as candidate sessions.
candidates += [
    os.path.join(os.path.expanduser("~"), "Downloads"),
    os.path.join(os.path.expanduser("~"), "Downloads", "ISIT_Dataset"),
    os.path.join(os.path.expanduser("~"), "Downloads", "archive"),
    os.path.join(os.getcwd(), "data", "raw"),
    os.path.join(os.getcwd(), "..", "data", "raw"),
    os.path.join(os.getcwd(), "..", "datasets"),
    os.path.join(os.getcwd(), "datasets"),
]
# dedupe & keep existing paths only
# (de-duplication happens on the raw strings, before abspath normalisation)
search_roots = [os.path.abspath(p) for p in dict.fromkeys(candidates) if p and os.path.exists(p)]
if not search_roots:
    # fallback to current working directory
    search_roots = [os.getcwd()]

print("Searching dataset roots:", search_roots)

# windowing parameters (events per window)
# Both values can be overridden through environment variables of the same name.
EVENT_WINDOW = int(os.environ.get("EVENT_WINDOW", 40))   # number of raw events per feature window
EVENT_STRIDE  = int(os.environ.get("EVENT_STRIDE", 20))  # stride in events
MIN_EVENTS_PER_SESSION = 10   # if fewer than this, skip

# helper: parse CSV (Balabit) expecting rows x,y[,t]
def parse_csv_events(path):
    """
    Read pointer events from a CSV file, one event per row.

    Only the cells of a row that parse as floats are considered: the first
    two become "x" and "y", a third (if any) becomes "t". Rows with fewer
    than two numeric cells (blank lines, text headers) are ignored.

    Returns a list of {"x", "y"[, "t"]} dicts. Any error while reading is
    printed as a traceback and the events collected so far are returned.
    """
    events = []
    try:
        with open(path, "r", newline='', encoding="utf-8", errors="ignore") as fh:
            for row in csv.reader(fh):
                numeric = []
                for cell in row:
                    try:
                        numeric.append(float(cell))
                    except ValueError:
                        # non-numeric cell (e.g. a header name): ignore it
                        continue
                if len(numeric) < 2:
                    continue
                record = {"x": numeric[0], "y": numeric[1]}
                if len(numeric) >= 3:
                    record["t"] = numeric[2]
                events.append(record)
    except Exception:
        traceback.print_exc()
    return events

# helper: parse JSON files from ISIT-like structure or generic arrays
def parse_json_events(path):
    """
    Read pointer events from a JSON file.

    Supported layouts, tried in this order:
      1. A top-level object with an event list under "trace", "events",
         "eventsList" or "data"; each event dict carries its coordinates in
         "position"/"pos" or directly as "x"/"y".
      2. A top-level object containing, anywhere inside it, a list of event
         dicts ("x"/"y" or "position") or of [x, y(, t)] lists.
      3. A top-level list of event dicts or [x, y(, t)] lists.

    Values are selected with explicit "is present" checks, so a coordinate
    or timestamp of 0 is kept (it used to be treated as missing, which
    dropped every event at x == 0 or y == 0). A null timestamp is treated as
    absent. Events whose values cannot be converted to float are skipped.

    Returns a list of {"x": float, "y": float[, "t": float]} dicts; the list
    is empty when nothing matched. Errors while reading or decoding the file
    are printed as a traceback and an empty list is returned.
    """
    bad_value = (TypeError, ValueError, OverflowError, IndexError)

    def first_present(mapping, keys):
        """Return the first value under ``keys`` that is neither None nor ""."""
        for k in keys:
            value = mapping.get(k)
            if value is not None and value != "":
                return value
        return None

    def make_event(x, y, t=None):
        """Build an event dict; raises when a value is not convertible."""
        d = {"x": float(x), "y": float(y)}
        if t is not None:
            d["t"] = float(t)
        return d

    def event_from_sequence(seq):
        """Build an event from an [x, y] or [x, y, t] list."""
        return make_event(seq[0], seq[1], seq[2] if len(seq) > 2 else None)

    ev = []
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            j = json.load(fh)
            # common patterns:
            # 1) top-level has 'trace' list of dicts with position/time
            if isinstance(j, dict):
                # many ISIT files have 'trace' or 'events'
                for key in ("trace", "events", "eventsList", "data"):
                    if key in j and isinstance(j[key], list):
                        for item in j[key]:
                            # item may be {'position': {'x':..,'y':..}, 'timestamp':...}
                            if not isinstance(item, dict):
                                continue
                            pos = None
                            if isinstance(item.get("position"), dict):
                                pos = item["position"]
                            elif isinstance(item.get("pos"), dict):
                                pos = item["pos"]
                            elif "x" in item and "y" in item:
                                pos = item
                            if not pos:
                                continue
                            x = first_present(pos, ("x", "X", "clientX", "pageX"))
                            y = first_present(pos, ("y", "Y", "clientY", "pageY"))
                            t = first_present(item, ("timestamp", "time", "ts", "t"))
                            if x is None or y is None:
                                continue
                            try:
                                ev.append(make_event(x, y, t))
                            except bad_value:
                                continue
                        if ev:
                            return ev

                # fallback: scan entire dict for lists of dicts with x,y
                def scan_for_list(o):
                    if isinstance(o, dict):
                        for v in o.values():
                            res = scan_for_list(v)
                            if res:
                                return res
                    elif isinstance(o, list):
                        found = []
                        for item in o:
                            if isinstance(item, dict) and (("x" in item and "y" in item) or ("position" in item)):
                                if isinstance(item["position"] if "position" in item else None, dict):
                                    p = item["position"]
                                    x = p.get("x"); y = p.get("y")
                                    t = first_present(item, ("timestamp", "time"))
                                    if x is not None and y is not None:
                                        try:
                                            found.append(make_event(x, y, t))
                                        except bad_value:
                                            pass
                                elif "x" in item and "y" in item:
                                    try:
                                        found.append(make_event(item["x"], item["y"], item.get("t")))
                                    except bad_value:
                                        pass
                            elif isinstance(item, list) and len(item) >= 2:
                                try:
                                    found.append(event_from_sequence(item))
                                except bad_value:
                                    pass
                            else:
                                # nested containers: the first nested hit wins
                                res = scan_for_list(item)
                                if res:
                                    return res
                        if found:
                            return found
                    return None
                res = scan_for_list(j)
                if res:
                    return res
            elif isinstance(j, list):
                # list of events
                for item in j:
                    if isinstance(item, dict):
                        if "x" in item and "y" in item:
                            try:
                                ev.append(make_event(item["x"], item["y"], item.get("t")))
                            except bad_value:
                                pass
                        elif "position" in item and isinstance(item["position"], dict):
                            p = item["position"]
                            try:
                                ev.append(make_event(p.get("x"), p.get("y"), item.get("timestamp")))
                            except bad_value:
                                pass
                    elif isinstance(item, list) and len(item) >= 2:
                        try:
                            ev.append(event_from_sequence(item))
                        except bad_value:
                            pass
            # final fallback: nothing matched
    except Exception:
        traceback.print_exc()
    return ev

# import feature extractor from your project (use relative import; ensure project root in sys.path)
# Two attempts are made: a plain import, then a retry after putting the parent of
# the working directory at the front of sys.path. If both fail the last traceback
# is printed and a RuntimeError stops the cell.
try:
    # try importing as package
    from backend.mouse_model import extract_features_from_events
    print("Imported extract_features_from_events from backend.mouse_model")
except Exception:
    try:
        # try adding parent path and importing
        # NOTE: the sys.path entry stays in place even if the retry fails.
        sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))
        from backend.mouse_model import extract_features_from_events
        print("Imported extract_features_from_events after adjusting sys.path")
    except Exception:
        traceback.print_exc()
        raise RuntimeError("Could not import backend.mouse_model.extract_features_from_events. Ensure your repo is present and PYTHONPATH includes project root.")

# Walk dataset roots and collect candidate files
# Every .csv / .json file (extension compared case-insensitively) below any
# search root is a candidate; a file reachable from two roots is listed twice.
found_files = []
for root in search_roots:
    for dirpath, dirnames, filenames in os.walk(root):
        for fn in filenames:
            if not fn.lower().endswith((".csv", ".json")):
                continue
            found_files.append(os.path.join(dirpath, fn))
print("Found candidate files:", len(found_files))

# Heuristic label detection from filename/path
def detect_label_from_path(p):
    """Guess a session label from its file path: 1 for bot, 0 for human.

    The complete lower-cased path (directories included) is scanned with plain
    substring matching, so a bot keyword anywhere in the path wins.
    Paths with a human keyword and paths with no keyword at all both map to 0.
    """
    haystack = p.lower()
    bot_markers = ("bot","gremlin","fake","robot","attack","malicious","automation","selenium","botnet")
    if any(marker in haystack for marker in bot_markers):
        return 1
    # Human markers ("human", "real", "benign", "genuine") and the no-match
    # fallback share the same label, so no further lookup is required.
    return 0

# Turn every candidate file into one session: parse its events, cut them into
# (possibly overlapping) chunks and run the feature extractor on each chunk.
session_windows = []   # list of (feats_list, label, path)
skipped = 0            # files that produced no usable session
processed = 0          # files that produced a session
failed_chunks = 0      # chunks on which the extractor raised for both input forms
MIN_CHUNK_EVENTS = 3   # shorter tail chunks are ignored
for fp in found_files:
    try:
        if fp.lower().endswith(".csv"):
            events = parse_csv_events(fp)
        else:
            events = parse_json_events(fp)
        if not events or len(events) < MIN_EVENTS_PER_SESSION:
            skipped += 1
            continue

        # build overlapping event-chunks -> each chunk is passed to the feature extractor.
        # range() rejects a zero stride (the file is then reported and skipped by the
        # outer handler) instead of looping forever.
        feats_list = []
        for start in range(0, len(events), EVENT_STRIDE):
            chunk = events[start:start + EVENT_WINDOW]
            if len(chunk) < MIN_CHUNK_EVENTS:
                continue
            try:
                fv = extract_features_from_events(chunk)
            except Exception:
                # extractor might expect plain x/y/t dicts - retry with coerced events
                try:
                    fv = extract_features_from_events(
                        [{"x": e.get("x"), "y": e.get("y"), "t": e.get("t")} for e in chunk]
                    )
                except Exception:
                    failed_chunks += 1
                    continue
            # the extractor may return None for chunks it cannot describe
            if fv is not None:
                feats_list.append(fv)
        if not feats_list:
            skipped += 1
            continue

        label = detect_label_from_path(fp)
        session_windows.append((feats_list, label, fp))
        processed += 1
        if processed % 200 == 0:
            print("Processed sessions:", processed, "skipped:", skipped)
    except Exception:
        traceback.print_exc()
        skipped += 1
        continue

print("Done. Sessions created:", len(session_windows), "skipped files:", skipped)
if failed_chunks:
    print("Chunks dropped because feature extraction raised:", failed_chunks)

# Persist session_windows so later cells can reload them after a kernel restart.
# Every window vector is flattened to a plain list of Python floats first.
simple_sessions = [
    {
        "features": [[float(v) for v in np.asarray(window).tolist()] for window in windows],
        "label": int(session_label),
        "path": source_path,
    }
    for windows, session_label, source_path in session_windows
]

# Full dump goes to a pickled .npy; the JSON file only holds a count and the first 5 sessions.
np_save_path = os.path.join(OUT_DIR, "session_windows.npy")
json_save_path = os.path.join(OUT_DIR, "session_windows.json")
np.save(np_save_path, simple_sessions, allow_pickle=True)
summary = {"n_sessions": len(simple_sessions), "sample": simple_sessions[:5]}
with open(json_save_path, "w", encoding="utf-8") as fh:
    json.dump(summary, fh, indent=2)

print("Saved session_windows: npy ->", np_save_path, " json ->", json_save_path)

# Rebuild the in-memory tuples (feats_list, label, path) from the simplified records
session_windows = [
    (record["features"], record["label"], record["path"]) for record in simple_sessions
]
print("session_windows ready in memory with", len(session_windows), "sessions (each contains per-window feature lists).")

In [None]:
# Cell C (patched) - Session-level train/test split, uses canonical feature names when needed
from sklearn.model_selection import train_test_split
import numpy as np
import os, json, traceback

# ---------- CONFIG ----------
# OUT_DIR resolution order:
#   1. the OUT_DIR environment variable,
#   2. the OUT_DIR an earlier cell already defined (the directory session_windows.npy
#      was written to), so this cell does not silently switch directories,
#   3. <current working directory>/data/processed.
OUT_DIR = (
    os.environ.get("OUT_DIR")
    or globals().get("OUT_DIR")
    or os.path.join(os.getcwd(), "data", "processed")
)
os.makedirs(OUT_DIR, exist_ok=True)
# Sequence/split settings, each overridable through an environment variable.
SEQ_LEN = int(os.environ.get("SEQ_LEN", 8))              # windows per sequence
TEST_SIZE = float(os.environ.get("TEST_SIZE", 0.20))     # fraction of sessions held out
RANDOM_STATE = int(os.environ.get("RANDOM_STATE", 42))   # seed for the session split
STRIDE = int(os.environ.get("SEQ_STRIDE", 1))            # step between consecutive sequences

print("OUT_DIR:", OUT_DIR, "SEQ_LEN:", SEQ_LEN, "TEST_SIZE:", TEST_SIZE)

# ---------- canonical feature names (order must match extract_features_from_events) ----------
# Human-readable names for the 20 per-window features. Further down they replace
# generic "f0..fN-1" names when the feature count matches.
# NOTE(review): the ordering is assumed to mirror the extractor's output vector —
# confirm against backend.mouse_model.
canonical_feature_names = [
    "mean_speed", "std_speed", "max_speed",
    "mean_acc", "std_acc", "max_acc",
    "mean_abs_dx", "std_dx",
    "mean_abs_dy", "std_dy",
    "mean_angles", "std_angles",
    "pause_frac",
    "bbox_aspect",
    "path_len",
    "speed_p25", "speed_p50", "speed_p75",
    "median_dt",
    "n_events"
]

# ---------- load selected features produced by Cell B ----------
# Accepted JSON layouts: a bare list, {"selected_features": ...} or {"selected": [...]}.
selected_file = os.path.join(OUT_DIR, "mouse_selected_features.json")
selected_features_from_file = None
if not os.path.exists(selected_file):
    print("No mouse_selected_features.json in OUT_DIR; will fallback to in-memory selected_features or all features.")
else:
    try:
        with open(selected_file, "r", encoding="utf-8") as fh:
            loaded = json.load(fh)
        if isinstance(loaded, list):
            selected_features_from_file = loaded
        elif isinstance(loaded, dict):
            if "selected_features" in loaded:
                selected_features_from_file = loaded["selected_features"]
            elif isinstance(loaded.get("selected"), list):
                selected_features_from_file = loaded["selected"]
        n_loaded = len(selected_features_from_file) if selected_features_from_file is not None else 0
        print("Loaded selected features from", selected_file, "-> count:", n_loaded)
    except Exception:
        # unreadable / malformed file: report it and let the later fallbacks take over
        traceback.print_exc()
        print("Failed to parse", selected_file, "- will fallback to notebook variable or defaults.")

# ---------- ensure session_windows present ----------
# This cell works on the in-memory session_windows only; it does not reload
# session_windows.npy from disk.
if "session_windows" not in globals() or session_windows is None:
    raise RuntimeError("session_windows not found in memory. Run preceding cells that build session_windows before Cell C.")

# ---------- build or load feature_names ----------
# Lookup order: notebook variable -> sequence_meta.json in OUT_DIR -> generic f0..fN-1.
feature_names = globals().get("feature_names", None)
if feature_names is None:
    seq_meta_path = os.path.join(OUT_DIR, "sequence_meta.json")
    if os.path.exists(seq_meta_path):
        try:
            with open(seq_meta_path, "r", encoding="utf-8") as fh:
                seq_meta = json.load(fh)
                feature_names = seq_meta.get("feature_names", None)
                if feature_names:
                    print("Loaded feature_names from sequence_meta.json")
        except Exception:
            # best effort: an unreadable meta file falls through to inference below
            pass

if feature_names is None:
    # infer from first window
    try:
        sample_feats_list, _, _ = session_windows[0]
        if len(sample_feats_list) == 0:
            raise RuntimeError("First session has no windows; cannot infer feature_names")
        sample_vec = np.asarray(sample_feats_list[0])
        # one generic name per entry of the first window's feature vector
        feature_names = [f"f{i}" for i in range(sample_vec.shape[0])]
        print("Inferred feature_names as default f0..fN-1 (N=%d)" % len(feature_names))
    except Exception as e:
        # also reached when session_windows is empty (index 0 does not exist)
        raise RuntimeError("Could not determine feature_names: " + str(e))

# ---------- substitute canonical names where that is provably safe ----------
# Case 1: names are the generic f0..fN-1 placeholders and N equals the canonical count.
# Case 2: the known names do not fit the observed feature-vector length (for example
#         stale names from an older run) but the canonical list does.
is_generic = all(str(fn).startswith("f") and str(fn)[1:].isdigit() for fn in feature_names)

# length of the first window's feature vector, when one is available
try:
    observed_dim = int(np.asarray(session_windows[0][0][0]).shape[0])
except Exception:
    observed_dim = None

if is_generic and len(feature_names) == len(canonical_feature_names):
    print("Detected generic feature_names (f0..). Substituting canonical feature names based on mouse_model.py")
    feature_names = canonical_feature_names.copy()
elif (
    observed_dim is not None
    and len(feature_names) != observed_dim
    and observed_dim == len(canonical_feature_names)
):
    print("feature_names length (%d) does not match observed feature dim (%d); using canonical feature names"
          % (len(feature_names), observed_dim))
    feature_names = canonical_feature_names.copy()

print("Using feature_names (count={}): {}".format(len(feature_names), feature_names if len(feature_names)<=20 else feature_names[:20]))

# ---------- determine selected_features (file > notebook > all) ----------
def _resolve_feature_indices(candidates, names, canonical):
    """Translate feature names and/or integer positions into column indices.

    Args:
        candidates: iterable of feature names and/or integer positions.
        names: the active feature-name list (one entry per feature column).
        canonical: canonical names, used positionally for a name missing from
            `names` when both lists have the same length.

    Returns:
        (indices, unresolved): column indices in candidate order, plus the
        candidates that could not be mapped (unknown names, out-of-range positions).
    """
    n_feats = len(names)
    canonical_usable = len(canonical) == n_feats
    indices, unresolved = [], []
    for item in candidates:
        idx = None
        if isinstance(item, (int, np.integer)) and not isinstance(item, (bool, np.bool_)):
            # accept Python-style negative positions, reject anything out of range
            if -n_feats <= int(item) < n_feats:
                idx = int(item) % n_feats
        elif item in names:
            idx = names.index(item)
        elif canonical_usable and item in canonical:
            idx = canonical.index(item)
        if idx is None:
            unresolved.append(item)
        else:
            indices.append(idx)
    return indices, unresolved


if selected_features_from_file:
    _origin = "mouse_selected_features.json"
    _candidates = selected_features_from_file
else:
    _origin = "notebook variable / default"
    _candidates = globals().get("selected_features", None) or feature_names
    print("Using selected_features from notebook or defaulting to all features (count=%d)" % len(_candidates))
if not isinstance(_candidates, (list, tuple)):
    # a scalar (e.g. a single name) is treated as a one-element selection
    _candidates = [_candidates]

# ---------- map selected_features to indices ----------
idxs, _unresolved = _resolve_feature_indices(_candidates, feature_names, canonical_feature_names)
if _unresolved:
    print("Warning: selected features from", _origin, "not found in feature_names and dropped:", _unresolved)
if not idxs:
    # nothing usable was selected: keep every feature rather than build empty arrays
    print("selected_features mismatch; falling back to all features")
    idxs = list(range(len(feature_names)))
selected_features = [feature_names[i] for i in idxs]

print("Selected feature count:", len(idxs), "Selected features:", selected_features)

# ---------- create session ids and labels ----------
# A session id is simply the position of the session inside session_windows.
session_ids = list(range(len(session_windows)))
labels = [int(sw[1]) for sw in session_windows]

# Stratified split at session level (fallback if stratify impossible)
# Splitting whole sessions keeps all windows of one session on the same side.
stratify_param = labels if len(set(labels)) > 1 else None
try:
    train_sids, test_sids = train_test_split(session_ids, test_size=TEST_SIZE, stratify=stratify_param, random_state=RANDOM_STATE)
except Exception as e:
    # stratification failed: retry without it
    print("Stratified split failed:", e, "-> using non-stratified split")
    train_sids, test_sids = train_test_split(session_ids, test_size=TEST_SIZE, random_state=RANDOM_STATE)

print("Train sessions:", len(train_sids), "Test sessions:", len(test_sids))

# ---------- build per-window arrays ----------
def build_per_window_arrays(sids):
    """Stack the selected features of every window of the given sessions.

    Reads the module-level `session_windows` (list of (feats_list, label, path))
    and `idxs` (selected feature column indices).

    Args:
        sids: iterable of session ids (positions in `session_windows`).

    Returns:
        (X, y): X has shape (n_windows, len(idxs)); y holds the session label
        repeated for each kept window. Both are empty when no window qualifies.
    """
    # smallest vector length that still contains every selected column
    needed = (max(idxs) + 1) if len(idxs) else 0
    rows, row_labels = [], []
    for sid in sids:
        feats_list, lab, _ = session_windows[sid]
        for f in feats_list:
            arr = np.asarray(f, dtype=float)
            if arr.ndim != 1 or arr.shape[0] < needed:
                # skip windows that are not flat vectors of the expected size
                continue
            rows.append(arr[idxs])
            row_labels.append(int(lab))
    if not rows:
        return np.zeros((0, len(idxs))), np.zeros((0,), dtype=int)
    return np.vstack(rows), np.array(row_labels, dtype=int)

# Build the flat per-window datasets for both sides of the session split.
Xw_train, yw_train = build_per_window_arrays(train_sids)
Xw_test, yw_test   = build_per_window_arrays(test_sids)
print("Per-window shapes -> train:", Xw_train.shape, " test:", Xw_test.shape)
try:
    print("Per-window label counts train:", np.bincount(yw_train), " test:", np.bincount(yw_test))
except Exception:
    # the label counts are informational only; never fail the cell over them
    pass

# ---------- build sequence arrays (SEQ_LEN windows per sequence) ----------
def build_sequences_for_sids(sids, seq_len=SEQ_LEN, stride=STRIDE):
    """Cut each session's windows into fixed-length sequences.

    Reads the module-level `session_windows` and `idxs`. Windows that are not flat
    vectors containing every selected column are ignored, mirroring
    build_per_window_arrays, so one malformed window cannot abort the build.

    Args:
        sids: iterable of session ids (positions in `session_windows`).
        seq_len: number of windows per sequence.
        stride: step between the start windows of consecutive sequences.

    Returns:
        (X, y, sid_map): X has shape (n_sequences, seq_len, len(idxs)); y holds one
        label per sequence; sid_map holds the session id each sequence came from.
        Sessions with fewer than seq_len usable windows produce a single sequence
        padded at the end with all-zero rows (an all-zero sequence when none is usable).
    """
    feat_dim = len(idxs)
    needed = (max(idxs) + 1) if feat_dim else 0
    Xs, ys, sid_map = [], [], []
    for sid in sids:
        feats_list, lab, _ = session_windows[sid]
        rows = []
        for f in feats_list:
            vec = np.asarray(f, dtype=float)
            if vec.ndim == 1 and vec.shape[0] >= needed:
                rows.append(vec[idxs])
        if not rows:
            Xs.append(np.zeros((seq_len, feat_dim), dtype=float))
            ys.append(int(lab))
            sid_map.append(sid)
            continue
        arr = np.vstack(rows)  # (n_windows, feat_dim)
        n_windows = arr.shape[0]
        if n_windows >= seq_len:
            # sliding window over the session
            for i in range(0, n_windows - seq_len + 1, stride):
                Xs.append(arr[i:i + seq_len])
                ys.append(int(lab))
                sid_map.append(sid)
        else:
            # short session: zero-pad at the end up to seq_len
            pad = np.zeros((seq_len - n_windows, arr.shape[1]), dtype=float)
            Xs.append(np.vstack([arr, pad]))
            ys.append(int(lab))
            sid_map.append(sid)
    if len(Xs) == 0:
        return np.zeros((0, seq_len, feat_dim)), np.zeros((0,), dtype=int), np.array([], dtype=int)
    return np.asarray(Xs, dtype=float), np.asarray(ys, dtype=int), np.asarray(sid_map, dtype=int)

# Build the sequence datasets; seq_sid_* record which session each sequence came from.
Xseq_train, yseq_train, seq_sid_train = build_sequences_for_sids(train_sids, seq_len=SEQ_LEN, stride=STRIDE)
Xseq_test, yseq_test, seq_sid_test   = build_sequences_for_sids(test_sids, seq_len=SEQ_LEN, stride=STRIDE)
print("Seq shapes -> train:", Xseq_train.shape, " test:", Xseq_test.shape)
try:
    print("Seq label counts train:", np.bincount(yseq_train), " test:", np.bincount(yseq_test))
except Exception:
    # the label counts are informational only; never fail the cell over them
    pass

# ---------- save outputs to OUT_DIR (so downstream cells can load after kernel restart) ----------
def safe_save(obj, name, is_json=False):
    """Write `obj` into OUT_DIR as JSON or .npy without ever raising.

    Args:
        obj: JSON-serialisable object (is_json=True) or array-like (otherwise).
        name: file name; ".json" / ".npy" is appended when it has neither suffix.
        is_json: choose json.dump (True) or np.save (False).

    Failures are printed together with a traceback; nothing is returned.
    """
    filename = name
    if not name.endswith((".npy", ".json")):
        filename = name + (".json" if is_json else ".npy")
    path = os.path.join(OUT_DIR, filename)
    try:
        if not is_json:
            np.save(path, obj, allow_pickle=True)
        else:
            with open(path, "w", encoding="utf-8") as fh:
                json.dump(obj, fh, indent=2)
    except Exception:
        traceback.print_exc()
        print("Failed to save", path)
    else:
        print("Saved", path)

# per-window arrays
safe_save(Xw_train, "Xw_train.npy")
safe_save(yw_train, "yw_train.npy")
safe_save(Xw_test, "Xw_test.npy")
safe_save(yw_test, "yw_test.npy")
# sequences
safe_save(Xseq_train, "Xseq_train.npy")
safe_save(yseq_train, "yseq_train.npy")
safe_save(Xseq_test, "Xseq_test.npy")
safe_save(yseq_test, "yseq_test.npy")
# split & meta
split_info = {
    "train_sids": [int(s) for s in train_sids],
    "test_sids": [int(s) for s in test_sids],
    "seq_len": SEQ_LEN,
    "selected_features": selected_features,
    "selected_indices": idxs,
    "feature_names": feature_names,
    # sequence -> session id maps. The ensemble cell looks up "seq_sid_test" in this
    # file when the variable is not in memory, so it has to be written here;
    # without it session-level evaluation is unavailable after a kernel restart.
    "seq_sid_train": [int(s) for s in seq_sid_train],
    "seq_sid_test": [int(s) for s in seq_sid_test],
}
safe_save(split_info, "session_split.json", is_json=True)

sequence_meta = {
    "seq_len": SEQ_LEN,
    "feat_dim": int(Xseq_train.shape[2]) if Xseq_train.size else len(idxs),
    "selected_features": selected_features,
    "selected_indices": idxs,
    "feature_names": feature_names
}
safe_save(sequence_meta, "sequence_meta.json", is_json=True)

print("Cell C complete — saved arrays and meta to", OUT_DIR)

In [None]:
# Cell D (patched) - Train final RF and save scaler + model (robust, saves into OUT_DIR)
import os
import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
import multiprocessing
import json, traceback

# ---------- CONFIG / OUT_DIR detection ----------
def find_project_root(max_up=6):
    """Return the nearest directory at or above the CWD that holds `backend` or `.git`.

    At most `max_up` parent levels are inspected; when no marker is found the
    resolved current working directory is returned instead.
    """
    from pathlib import Path
    start = Path.cwd()
    # the CWD itself followed by its ancestors, nearest first
    lineage = [start, *start.parents]
    for folder in lineage[: max(0, max_up + 1)]:
        if any((folder / marker).exists() for marker in ("backend", ".git")):
            return str(folder.resolve())
    return str(Path.cwd().resolve())

# PROJECT_ROOT / OUT_DIR environment variables take precedence over auto-detection.
ROOT = os.environ.get("PROJECT_ROOT") or find_project_root()
OUT_DIR = os.environ.get("OUT_DIR", os.path.join(ROOT, "data", "processed"))
os.makedirs(OUT_DIR, exist_ok=True)

print("ROOT:", ROOT)
print("OUT_DIR:", OUT_DIR)

# ---------- sanity checks ----------
# The per-window arrays must already be notebook variables; this cell does not
# read the saved .npy files itself.
required_vars = ["Xw_train", "yw_train", "Xw_test", "yw_test"]
missing = [v for v in required_vars if v not in globals()]
if missing:
    raise RuntimeError(f"Missing variables in memory required for training: {missing}. Run Cell C first or load saved .npy files into notebook.")

# ensure numpy arrays
Xw_train = np.asarray(globals()["Xw_train"], dtype=float)
Xw_test  = np.asarray(globals()["Xw_test"], dtype=float)
yw_train = np.asarray(globals()["yw_train"], dtype=int)
yw_test  = np.asarray(globals()["yw_test"], dtype=int)

print("Shapes: Xw_train", Xw_train.shape, "yw_train", yw_train.shape, "Xw_test", Xw_test.shape, "yw_test", yw_test.shape)

# ---------- handle edge cases ----------
if Xw_train.size == 0 or Xw_test.size == 0:
    raise RuntimeError("Empty training or test arrays (check session_windows / selected features).")

# choose n_jobs sensibly (leave 1 core free if possible)
# RF_N_JOBS overrides the default; an unparsable value falls back to cpu_count - 1.
cpu_count = multiprocessing.cpu_count()
n_jobs_env = os.environ.get("RF_N_JOBS")
try:
    n_jobs = int(n_jobs_env) if n_jobs_env else max(1, cpu_count - 1)
except Exception:
    n_jobs = max(1, cpu_count - 1)

print("Using n_jobs =", n_jobs, " (cpu_count reported as", cpu_count, ")")

# ---------- scale features ----------
# The scaler is fitted on the training windows only and then applied to both sets.
scaler_rf = StandardScaler().fit(Xw_train)
Xw_train_s = scaler_rf.transform(Xw_train)
Xw_test_s  = scaler_rf.transform(Xw_test)

# ---------- train RandomForest ----------
# Hyper-parameters can be overridden via RF_N_ESTIMATORS / RF_CLASS_WEIGHT / RANDOM_STATE.
# NOTE: RF_CLASS_WEIGHT is passed through as the raw environment string.
rf_final = RandomForestClassifier(
    n_estimators=int(os.environ.get("RF_N_ESTIMATORS", 300)),
    class_weight=os.environ.get("RF_CLASS_WEIGHT", "balanced"),
    n_jobs=n_jobs,
    random_state=int(os.environ.get("RANDOM_STATE", 42))
)

print("Training final RF (selected features)...")
try:
    rf_final.fit(Xw_train_s, yw_train)
except Exception:
    # show the traceback in the notebook output, then propagate the failure
    traceback.print_exc()
    raise

# ---------- evaluation (per-window) ----------
print("\n--- RF per-window evaluation (test set) ---")
yhat = rf_final.predict(Xw_test_s)
proba = None
try:
    # column 1 is used as the positive-class (label 1) probability
    proba = rf_final.predict_proba(Xw_test_s)[:, 1]
except Exception:
    # some sklearn classifiers can fail to provide predict_proba
    pass

print("Classification report:")
print(classification_report(yw_test, yhat, digits=4))

try:
    cm = confusion_matrix(yw_test, yhat)
    print("Confusion matrix:\n", cm)
except Exception:
    pass

# AUC is only attempted when probabilities exist and both classes occur in the test labels.
if proba is not None and len(np.unique(yw_test)) > 1:
    try:
        auc = roc_auc_score(yw_test, proba)
        print("ROC AUC (per-window): {:.6f}".format(auc))
    except Exception:
        print("ROC AUC computation failed.")
else:
    print("ROC AUC skipped (no probabilities or single-class test set).")

# ---------- save artifacts ----------
# Model and scaler are written with joblib; a failure here aborts the cell.
rf_path = os.path.join(OUT_DIR, "mouse_rf.save")
scaler_path = os.path.join(OUT_DIR, "mouse_scaler.save")

try:
    joblib.dump(rf_final, rf_path)
    joblib.dump(scaler_rf, scaler_path)
    print(f"Saved RF model -> {rf_path}")
    print(f"Saved scaler -> {scaler_path}")
except Exception:
    traceback.print_exc()
    raise

# also save a small metadata JSON (feature names / shapes) if available
# feature_names / selected_features are copied from notebook variables when present.
meta = {}
if "feature_names" in globals():
    meta["feature_names"] = globals().get("feature_names")
if "selected_features" in globals():
    meta["selected_features"] = globals().get("selected_features")
meta.update({
    "Xw_train_shape": list(Xw_train.shape),
    "Xw_test_shape": list(Xw_test.shape),
    "yw_train_counts": list(map(int, np.bincount(yw_train))) if yw_train.size else [],
    "yw_test_counts": list(map(int, np.bincount(yw_test))) if yw_test.size else []
})

# Writing the metadata is best effort: a failure is reported but does not abort the cell.
try:
    with open(os.path.join(OUT_DIR, "mouse_rf_meta.json"), "w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=2)
    print("Saved mouse_rf_meta.json")
except Exception:
    traceback.print_exc()
    print("Failed to save mouse_rf_meta.json")

print("Cell D complete.")

In [None]:
# Cell E (patched) - Train LSTM (sequence model), save mouse_lstm.h5, mouse_lstm.keras, mouse_lstm_scaler.save, mouse_lstm_meta.json
import os
import json
import sys
import traceback
import numpy as np
import joblib
from pathlib import Path

# ---------- OUT_DIR / ROOT detection ----------
def find_project_root(max_up=6):
    """Locate the project root by searching upwards from the working directory.

    A directory qualifies when it contains a `backend` or `.git` entry. The CWD
    and up to `max_up` ancestors are examined; without a match the resolved CWD
    is returned.
    """
    here = Path.cwd()
    candidates = (here, *here.parents)[: max(0, max_up + 1)]
    for folder in candidates:
        has_backend = (folder / "backend").exists()
        if has_backend or (folder / ".git").exists():
            return str(folder.resolve())
    return str(Path.cwd().resolve())

# PROJECT_ROOT / OUT_DIR environment variables take precedence over auto-detection.
ROOT = os.environ.get("PROJECT_ROOT") or find_project_root()
OUT_DIR = os.environ.get("OUT_DIR", os.path.join(ROOT, "data", "processed"))
os.makedirs(OUT_DIR, exist_ok=True)
print("ROOT:", ROOT)
print("OUT_DIR:", OUT_DIR)

# ---------- check for required saved arrays (Cell C outputs) ----------
def load_npy(name):
    """Load the array stored as OUT_DIR/<name>.

    Raises:
        FileNotFoundError: when the file does not exist.
    """
    array_path = os.path.join(OUT_DIR, name)
    if os.path.exists(array_path):
        # NOTE: allow_pickle=True will unpickle object arrays - only load trusted files
        return np.load(array_path, allow_pickle=True)
    raise FileNotFoundError(f"Required file not found: {array_path}")

# Always read the sequence arrays from disk (in-memory copies are not consulted),
# so this cell can run right after a kernel restart.
try:
    Xseq_train = load_npy("Xseq_train.npy")
    yseq_train = load_npy("yseq_train.npy")
    Xseq_test  = load_npy("Xseq_test.npy")
    yseq_test  = load_npy("yseq_test.npy")
except Exception as e:
    # print a short hint, then let the original error stop the cell
    print("Could not load sequence arrays from OUT_DIR:", e)
    raise

print("Loaded sequence arrays: Xseq_train", Xseq_train.shape, "yseq_train", yseq_train.shape,
      "Xseq_test", Xseq_test.shape, "yseq_test", yseq_test.shape)

# ---------- TensorFlow import (fail gracefully) ----------
try:
    import tensorflow as tf
    from tensorflow.keras import layers, models, callbacks, optimizers, metrics
except Exception:
    traceback.print_exc()
    raise RuntimeError("TensorFlow import failed. Install TensorFlow in the venv to train LSTM.")

# ---------- config / hyperparams (env override) ----------
# SEQ_LEN / FEAT_DIM default to the shape of the loaded training array.
# NOTE(review): an env override that disagrees with the array shape is not validated
# here and would change how the reshape(-1, FEAT_DIM) below groups values — confirm
# these overrides are really wanted.
SEQ_LEN = int(os.environ.get("SEQ_LEN", Xseq_train.shape[1]))
FEAT_DIM = int(os.environ.get("FEAT_DIM", Xseq_train.shape[2]))
BATCH_SIZE = int(os.environ.get("LSTM_BATCH", 64))
EPOCHS = int(os.environ.get("LSTM_EPOCHS", 40))
PATIENCE = int(os.environ.get("LSTM_PATIENCE", 6))   # early-stopping patience (epochs)
LR = float(os.environ.get("LSTM_LR", 1e-3))          # Adam learning rate

# Output artifact locations inside OUT_DIR.
OUT_MODEL_H5 = os.path.join(OUT_DIR, "mouse_lstm.h5")
OUT_MODEL_KERAS = os.path.join(OUT_DIR, "mouse_lstm.keras")
OUT_SCALER = os.path.join(OUT_DIR, "mouse_lstm_scaler.save")
OUT_META = os.path.join(OUT_DIR, "mouse_lstm_meta.json")

print("SEQ_LEN:", SEQ_LEN, "FEAT_DIM:", FEAT_DIM, "BATCH_SIZE:", BATCH_SIZE, "EPOCHS:", EPOCHS, "LR:", LR)

# ---------- prepare flattened scaler (fit on training sequences flattened along time axis) ----------
from sklearn.preprocessing import StandardScaler

# flatten sequences to (n_samples * seq_len, feat_dim) for per-feature scaling
Xtr_flat = Xseq_train.reshape(-1, FEAT_DIM)
# NOTE: Xte_flat is not referenced again in the visible part of the notebook.
Xte_flat = Xseq_test.reshape(-1, FEAT_DIM)

# NOTE(review): the flattened rows include the all-zero padding rows that Cell C
# appends to short sessions, so they take part in the mean/std estimate. After
# transform() a zero row becomes (0 - mean) / scale, which is generally non-zero,
# so the Masking(mask_value=0.0) layer of the model presumably no longer recognises
# padded timesteps. The ensemble cell scales its inputs the same way, so any change
# has to be made in both places together — confirm before touching.
scaler = StandardScaler().fit(Xtr_flat)

def scale_seqs(X):
    """Scale a (n, seq_len, FEAT_DIM) array per feature with the fitted `scaler`.

    The array is flattened to rows of FEAT_DIM values, transformed, and reshaped
    back to the input's shape.
    """
    flat = X.reshape(-1, FEAT_DIM)
    s = scaler.transform(flat)
    return s.reshape(X.shape)

Xseq_train_s = scale_seqs(Xseq_train).astype(np.float32)
Xseq_test_s  = scale_seqs(Xseq_test).astype(np.float32)

# save scaler now (joblib)
joblib.dump(scaler, OUT_SCALER)
print("Saved LSTM scaler ->", OUT_SCALER)

# ---------- build model (small, CPU-friendly) ----------
def build_lstm_model(seq_len=SEQ_LEN, feat_dim=FEAT_DIM, dropout=0.2, lstm_units=64, dense_units=32):
    """Assemble the (uncompiled) binary sequence classifier.

    Layer order: Masking -> LSTM -> BatchNormalization -> Dropout -> Dense(relu)
    -> Dropout (half rate) -> Dense(1, sigmoid).

    Args:
        seq_len: timesteps per input sequence.
        feat_dim: features per timestep.
        dropout: rate of the first dropout layer; the second one uses half of it.
        lstm_units: size of the LSTM state.
        dense_units: width of the hidden dense layer.

    Returns:
        A Keras functional model named "mouse_lstm_model".
    """
    seq_input = layers.Input(shape=(seq_len, feat_dim), name="seq_input")
    # Timesteps whose features all equal 0.0 are masked (built-in, serializable layer)
    hidden = layers.Masking(mask_value=0.0, name="masking")(seq_input)
    hidden = layers.LSTM(lstm_units, return_sequences=False, name="lstm")(hidden)
    hidden = layers.BatchNormalization(name="batch_norm")(hidden)
    hidden = layers.Dropout(dropout, name="dropout1")(hidden)
    hidden = layers.Dense(dense_units, activation="relu", name="dense1")(hidden)
    hidden = layers.Dropout(dropout * 0.5, name="dropout2")(hidden)
    bot_probability = layers.Dense(1, activation="sigmoid", name="out")(hidden)
    return models.Model(inputs=seq_input, outputs=bot_probability, name="mouse_lstm_model")

# Build and compile the model; AUC (reported as "auc"/"val_auc") is the tracked metric.
model = build_lstm_model()
optimizer = optimizers.Adam(learning_rate=LR)
auc_metric = metrics.AUC(name="auc")
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=[auc_metric])
model.summary()

# ---------- prepare training callbacks ----------
# All three callbacks watch validation AUC (higher is better):
#   es  - stop after PATIENCE epochs without improvement and restore the best weights
#   mc  - write the best model so far to OUT_MODEL_H5
#   rlr - halve the learning rate on a plateau, down to 1e-6
es = callbacks.EarlyStopping(monitor="val_auc", mode="max", patience=PATIENCE, restore_best_weights=True, verbose=1)
mc = callbacks.ModelCheckpoint(OUT_MODEL_H5, save_best_only=True, monitor="val_auc", mode="max", verbose=1)
rlr = callbacks.ReduceLROnPlateau(monitor="val_auc", mode="max", factor=0.5, patience=max(2, PATIENCE//2), min_lr=1e-6, verbose=1)

# ---------- class weights (optional to handle imbalance) ----------
# "balanced" weights are computed only when both classes occur in the training labels;
# any failure here simply disables class weighting.
try:
    from sklearn.utils.class_weight import compute_class_weight
    classes = np.unique(yseq_train)
    if len(classes) > 1:
        cw = compute_class_weight(class_weight="balanced", classes=classes, y=yseq_train)
        class_weight = {int(c): float(w) for c, w in zip(classes, cw)}
        print("Class weight:", class_weight)
    else:
        class_weight = None
except Exception:
    class_weight = None

# ---------- train ----------
# NOTE(review): validation_split holds out the trailing 15% of the sequence array;
# sequences of one session may therefore land on both sides of that boundary — confirm
# this is acceptable for model selection.
print("Training LSTM (CPU-friendly)...")
history = None
try:
    history = model.fit(
        Xseq_train_s, yseq_train,
        validation_split=0.15,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[es, mc, rlr],
        class_weight=class_weight,
        verbose=2
    )
except Exception:
    # show the traceback in the notebook output, then propagate the failure
    traceback.print_exc()
    raise

# ---------- save final model (if checkpoint didn't produce file) ----------
if not os.path.exists(OUT_MODEL_H5):
    try:
        model.save(OUT_MODEL_H5)
        print("Saved model to", OUT_MODEL_H5)
    except Exception:
        traceback.print_exc()
        print("Warning: failed to save model via model.save(h5)")

# Also save native Keras (.keras) which is preferred for portability
# This writes the in-memory model, independently of the checkpoint file above.
try:
    model.save(OUT_MODEL_KERAS)   # native Keras format (zip)
    print("Saved native Keras model ->", OUT_MODEL_KERAS)
except Exception:
    traceback.print_exc()
    print("Warning: failed to save native .keras model (this is non-fatal)")

# ---------- evaluate on test set ----------
# Evaluates the in-memory model on the scaled test sequences.
print("Evaluating on test set...")
res = model.evaluate(Xseq_test_s, yseq_test, batch_size=BATCH_SIZE, verbose=2)
print("Test metrics (loss + auc):", res)

# ---------- save metadata ----------
# build class counts safely
train_vals, train_counts = np.unique(yseq_train, return_counts=True)
test_vals, test_counts = np.unique(yseq_test, return_counts=True)
# Paths are stored as absolute paths; the .keras entry is None when that save failed.
meta = {
    "seq_len": SEQ_LEN,
    "feat_dim": FEAT_DIM,
    "model_file_h5": os.path.abspath(OUT_MODEL_H5),
    "model_file_keras": os.path.abspath(OUT_MODEL_KERAS) if os.path.exists(OUT_MODEL_KERAS) else None,
    "scaler_file": os.path.abspath(OUT_SCALER),
    "train_shape": list(Xseq_train_s.shape),
    "test_shape": list(Xseq_test_s.shape),
    "class_counts_train": {int(k): int(v) for k, v in zip(train_vals, train_counts)} if train_vals.size else {},
    "class_counts_test":  {int(k): int(v) for k, v in zip(test_vals, test_counts)} if test_vals.size else {}
}
with open(OUT_META, "w", encoding="utf-8") as fh:
    json.dump(meta, fh, indent=2)
print("Saved LSTM metadata ->", OUT_META)

# Final summary of the artifacts this cell wrote.
print("Cell E complete. LSTM artifacts written to OUT_DIR:")
print(" -", OUT_MODEL_H5)
print(" -", OUT_MODEL_KERAS if os.path.exists(OUT_MODEL_KERAS) else "(native .keras not saved)")
print(" -", OUT_SCALER)
print(" -", OUT_META)

In [None]:
# Cell F (patched) - Ensemble RF + LSTM evaluation (robust: loads artifacts if missing, saves ensemble meta)
import os, json, joblib, traceback
import numpy as np

# ---------- detect project root / OUT_DIR ----------
from pathlib import Path
def find_project_root(max_up=6):
    """Walk up from the CWD and return the first directory containing `backend` or `.git`.

    The search covers the CWD plus at most `max_up` parent directories; if none
    of them carries a marker, the resolved CWD is returned.
    """
    origin = Path.cwd()
    search_path = [origin]
    search_path.extend(origin.parents)
    for folder in search_path[: max(0, max_up + 1)]:
        for marker in ("backend", ".git"):
            if (folder / marker).exists():
                return str(folder.resolve())
    return str(Path.cwd().resolve())

# PROJECT_ROOT / OUT_DIR environment variables take precedence over auto-detection.
ROOT = os.environ.get("PROJECT_ROOT") or find_project_root()
OUT_DIR = os.environ.get("OUT_DIR", os.path.join(ROOT, "data", "processed"))
os.makedirs(OUT_DIR, exist_ok=True)
print("ROOT:", ROOT)
print("OUT_DIR:", OUT_DIR)

# ---------- helper to load numpy arrays saved by previous cells ----------
def load_npy_if_missing(varname, filename):
    """Return the notebook variable `varname`, or load OUT_DIR/<filename> instead.

    A variable that exists but is None counts as missing (same rule as
    load_joblib_if_missing), so a placeholder never shadows the file on disk.

    Args:
        varname: name of the module-level variable to reuse when available.
        filename: .npy file inside OUT_DIR used as the fallback.

    Raises:
        FileNotFoundError: when neither the variable nor the file is available.
    """
    in_memory = globals().get(varname)
    if in_memory is not None:
        return in_memory
    path = os.path.join(OUT_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required array {filename} not found at {path} and {varname} not in memory.")
    print(f"Loading {filename} from disk")
    # NOTE: allow_pickle=True will unpickle object arrays - only load trusted files
    return np.load(path, allow_pickle=True)

# ---------- load sequence & mapping arrays ----------
Xseq_test = load_npy_if_missing("Xseq_test", "Xseq_test.npy")
yseq_test = load_npy_if_missing("yseq_test", "yseq_test.npy")

# seq_sid_test maps each sequence back to session id (used for session-level aggregation).
# Source order: notebook variable, then session_split.json. It stays None when neither
# provides it, so later code can tell "unknown" apart from "no sequences".
if globals().get("seq_sid_test") is not None:
    seq_sid_test = globals()["seq_sid_test"]
else:
    seq_sid_test = None
    ss_path = os.path.join(OUT_DIR, "session_split.json")
    if os.path.exists(ss_path):
        try:
            # context manager closes the file even if parsing fails
            with open(ss_path, "r", encoding="utf-8") as fh:
                ss = json.load(fh)
            _sid_values = ss.get("seq_sid_test") or ss.get("seq_sid_test.npy")
            if _sid_values:
                seq_sid_test = np.asarray(_sid_values, dtype=int)
        except Exception:
            seq_sid_test = None

# ---------- load RF artifacts (scaler_rf + rf_final) ----------
def load_joblib_if_missing(varname, filename):
    """Reuse the notebook variable `varname` if it is set, else joblib-load OUT_DIR/<filename>.

    Raises:
        FileNotFoundError: when the variable is absent/None and the file does not exist.
    """
    namespace = globals()
    if namespace.get(varname, None) is not None:
        return namespace[varname]
    artifact_path = os.path.join(OUT_DIR, filename)
    if os.path.exists(artifact_path):
        print(f"Loading {filename} from disk")
        return joblib.load(artifact_path)
    raise FileNotFoundError(f"Required model file {filename} not found at {artifact_path} and {varname} not in memory.")

# Load the RF scaler and model (notebook variables first, then files in OUT_DIR).
try:
    scaler_rf = load_joblib_if_missing("scaler_rf", "mouse_scaler.save")
    rf_final = load_joblib_if_missing("rf_final", "mouse_rf.save")
except Exception:
    # try alternative names
    # Only the scaler has an alternative file name; the model is retried under the
    # same "mouse_rf.save" name, so a missing model file still raises here.
    try:
        scaler_rf = load_joblib_if_missing("scaler_rf", "mouse_scaler.joblib")
    except Exception:
        raise
    rf_final = load_joblib_if_missing("rf_final", "mouse_rf.save")

# ---------- compute RF sequence-level probabilities ----------
n_seq_test = int(Xseq_test.shape[0])
seq_len = int(Xseq_test.shape[1])
feat_dim = int(Xseq_test.shape[2])
print("Xseq_test shape:", Xseq_test.shape)

# Flatten windows -> scale with scaler_rf -> predict proba -> reshape
flat_windows = Xseq_test.reshape(-1, feat_dim)
try:
    flat_windows_scaled_for_rf = scaler_rf.transform(flat_windows)
except Exception as err:
    # keep the original error attached so the real cause stays visible
    raise RuntimeError(f"scaler_rf.transform failed for flat_windows of shape {flat_windows.shape}. Check that selected features used for RF match sequence features.") from err

rf_flat_probs = rf_final.predict_proba(flat_windows_scaled_for_rf)[:, 1]
rf_probs_seq = rf_flat_probs.reshape(n_seq_test, seq_len)

# Short sessions are zero-padded up to seq_len when the sequences are built, while the
# RF is fitted on real windows only. Average over the real (not all-zero) windows so
# padding rows do not pull the sequence score around; a sequence made of padding only
# keeps the plain mean.
real_window = np.any(np.asarray(Xseq_test, dtype=float) != 0.0, axis=2)  # (n_seq, seq_len)
n_real = real_window.sum(axis=1)
rf_seq_prob = np.where(
    n_real > 0,
    (rf_probs_seq * real_window).sum(axis=1) / np.maximum(n_real, 1),
    rf_probs_seq.mean(axis=1),
)

# ---------- prepare LSTM predictions ----------
# load LSTM model & scaler if missing
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model as tf_load_model
except Exception:
    traceback.print_exc()
    raise RuntimeError("TensorFlow import failed. Install TF to run LSTM inference.")

# load LSTM model if not in memory
# A notebook variable named `model` is reused when it exposes a `predict` attribute.
if "model" in globals() and getattr(globals()["model"], "predict", None):
    lstm_model = globals()["model"]
else:
    # prefer model file names placed earlier
    # first existing candidate wins, so the .h5 file takes priority over .keras
    model_path_candidates = [
        os.path.join(OUT_DIR, "mouse_lstm.h5"),
        os.path.join(OUT_DIR, "mouse_lstm.keras")
    ]
    model_path = next((p for p in model_path_candidates if os.path.exists(p)), None)
    if not model_path:
        raise FileNotFoundError("LSTM model not found in memory nor mouse_lstm.h5/mouse_lstm.keras in OUT_DIR.")
    print("Loading LSTM model from:", model_path)
    lstm_model = tf_load_model(model_path)

# load LSTM scaler used in Cell E (name: mouse_lstm_scaler.save)
# NOTE(review): any notebook variable called `scaler` is accepted here without a type
# check — presumably the LSTM scaler from Cell E; confirm nothing else reuses that name.
if "scaler" in globals():
    lstm_scaler = globals()["scaler"]
else:
    scaler_candidates = [
        os.path.join(OUT_DIR, "mouse_lstm_scaler.save"),
        os.path.join(OUT_DIR, "mouse_lstm_scaler.joblib"),
        os.path.join(OUT_DIR, "mouse_lstm_scaler.pkl")
    ]
    scaler_path = next((p for p in scaler_candidates if os.path.exists(p)), None)
    if not scaler_path:
        raise FileNotFoundError("LSTM scaler not found in memory nor in OUT_DIR (mouse_lstm_scaler.save).")
    lstm_scaler = joblib.load(scaler_path)
    print("Loaded LSTM scaler from:", scaler_path)

# scale sequences same as during training: flatten -> transform -> reshape
flat_seq_test = Xseq_test.reshape(-1, feat_dim)
flat_seq_test_s = lstm_scaler.transform(flat_seq_test)
Xseq_test_s = flat_seq_test_s.reshape(Xseq_test.shape).astype(np.float32)

# compute LSTM sequence-level probs
# ravel() flattens the model output to one probability per sequence
try:
    lstm_seq_prob = lstm_model.predict(Xseq_test_s, batch_size=int(os.environ.get("LSTM_BATCH", 64)), verbose=1).ravel()
except Exception:
    # fallback: try smaller batch
    lstm_seq_prob = lstm_model.predict(Xseq_test_s, batch_size=16, verbose=1).ravel()

# ---------- ensemble ----------
# Weighted average of the RF and LSTM sequence probabilities.
# Weights come from ENS_WEIGHT_RF / ENS_WEIGHT_LSTM, the decision threshold from ENS_THRESHOLD.
w_rf = float(os.environ.get("ENS_WEIGHT_RF", 0.5))
w_lstm = float(os.environ.get("ENS_WEIGHT_LSTM", 0.5))
if w_rf + w_lstm == 0:
    # a zero weight sum would turn every ensemble probability into NaN/inf
    raise ValueError("ENS_WEIGHT_RF + ENS_WEIGHT_LSTM must not be zero.")
ens_threshold = float(os.environ.get("ENS_THRESHOLD", 0.5))  # parsed once, shared by all three models

ensemble_seq_prob = (w_rf * rf_seq_prob + w_lstm * lstm_seq_prob) / (w_rf + w_lstm)
ensemble_seq_pred = (ensemble_seq_prob >= ens_threshold).astype(int)
lstm_seq_pred = (lstm_seq_prob >= ens_threshold).astype(int)
rf_seq_pred = (rf_seq_prob >= ens_threshold).astype(int)

# ---------- evaluation (per-sequence) ----------
from sklearn.metrics import classification_report, roc_auc_score

# For each of LSTM, RF and the ensemble: a classification report on the thresholded
# predictions plus the AUC of the raw probabilities. An AUC failure is reported
# instead of stopping the cell.
print("\nLSTM seq report:")
print(classification_report(yseq_test, lstm_seq_pred, digits=4))
try:
    print("LSTM seq AUC:", roc_auc_score(yseq_test, lstm_seq_prob))
except Exception:
    print("LSTM seq AUC: could not compute (single-class or error)")

print("\nRF seq report (avg windows):")
print(classification_report(yseq_test, rf_seq_pred, digits=4))
try:
    print("RF seq AUC:", roc_auc_score(yseq_test, rf_seq_prob))
except Exception:
    print("RF seq AUC: could not compute")

print("\nENSEMBLE (avg) seq report:")
print(classification_report(yseq_test, ensemble_seq_pred, digits=4))
try:
    print("ENSEMBLE seq AUC:", roc_auc_score(yseq_test, ensemble_seq_prob))
except Exception:
    print("ENSEMBLE seq AUC: could not compute")

# ---------- session-level evaluation (average sequence probs per session) ----------
if seq_sid_test is None or len(seq_sid_test) != len(ensemble_seq_prob):
    # In-memory mapping is missing or does not line up with the predictions:
    # attempt to load seq_sid_test from saved session_split.json
    ss_path = os.path.join(OUT_DIR, "session_split.json")
    if os.path.exists(ss_path):
        try:
            # context manager so the handle is closed even when parsing fails
            with open(ss_path, "r", encoding="utf-8") as ss_fh:
                ss = json.load(ss_fh)
            seq_sid_test = np.asarray(ss.get("seq_sid_test") or ss.get("seq_sid_test.npy") or ss.get("seq_sid_test_list") or [], dtype=int)
            print("Loaded seq_sid_test from session_split.json")
        except Exception as exc:
            # Best-effort: session-level evaluation is skipped below, but say why.
            print("Could not load seq_sid_test from session_split.json:", exc)
            seq_sid_test = None

if seq_sid_test is not None and len(seq_sid_test) == len(ensemble_seq_prob):
    sess_threshold = float(os.environ.get("ENS_THRESHOLD", 0.5))
    sess_ids = np.unique(seq_sid_test)
    sess_true = []
    sess_rf_prob = []
    sess_lstm_prob = []
    sess_ensemble_prob = []
    for sid in sess_ids:
        mask = (seq_sid_test == sid)
        if np.sum(mask) == 0:
            continue
        # the label of the session's first sequence is used as the session label
        true_label = int(yseq_test[mask][0])
        sess_true.append(true_label)
        sess_rf_prob.append(float(rf_seq_prob[mask].mean()))
        sess_lstm_prob.append(float(lstm_seq_prob[mask].mean()))
        sess_ensemble_prob.append(float(ensemble_seq_prob[mask].mean()))

    sess_true = np.array(sess_true)
    sess_rf_prob = np.array(sess_rf_prob)
    sess_lstm_prob = np.array(sess_lstm_prob)
    sess_ensemble_prob = np.array(sess_ensemble_prob)

    print("\nSession-level evaluation (averaging sequence probs per session):")
    print("\nRF session-level:")
    rf_sess_pred = (sess_rf_prob >= sess_threshold).astype(int)
    print(classification_report(sess_true, rf_sess_pred, digits=4))
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate
    try:
        print("RF session AUC:", roc_auc_score(sess_true, sess_rf_prob))
    except Exception:
        print("RF session AUC: could not compute")

    print("\nLSTM session-level:")
    lstm_sess_pred = (sess_lstm_prob >= sess_threshold).astype(int)
    print(classification_report(sess_true, lstm_sess_pred, digits=4))
    try:
        print("LSTM session AUC:", roc_auc_score(sess_true, sess_lstm_prob))
    except Exception:
        print("LSTM session AUC: could not compute")

    print("\nEnsemble session-level:")
    ens_sess_pred = (sess_ensemble_prob >= sess_threshold).astype(int)
    print(classification_report(sess_true, ens_sess_pred, digits=4))
    try:
        print("Ensemble session AUC:", roc_auc_score(sess_true, sess_ensemble_prob))
    except Exception:
        print("Ensemble session AUC: could not compute")
else:
    print("\nSkipping session-level evaluation: seq_sid_test unavailable or length mismatch.")

# ---------- save ensemble metadata ----------
# Aliased import so no `datetime` name defined by another cell is clobbered.
import datetime as _dt

ens_meta_path = os.path.join(OUT_DIR, "mouse_ensemble_meta.json")
ens_meta = {
    "rf_weight": w_rf,
    "lstm_weight": w_lstm,
    "method": os.environ.get("ENS_METHOD", "weighted_avg"),
    "threshold": float(os.environ.get("ENS_THRESHOLD", 0.5)),
    "n_sequences_test": int(len(ensemble_seq_prob)),
    "rf_seq_shape": list(rf_seq_prob.shape),
    "lstm_seq_shape": list(lstm_seq_prob.shape),
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # timestamp and drop tzinfo so the serialised form stays "<naive ISO>Z".
    "created_at": _dt.datetime.now(_dt.timezone.utc).replace(tzinfo=None).isoformat() + "Z",
}
with open(ens_meta_path, "w", encoding="utf-8") as fh:
    json.dump(ens_meta, fh, indent=2)
print("Saved ensemble meta to", ens_meta_path)

In [None]:
# Cell G - Visualization (inline display + save): ROC curves + Confusion Matrices for RF / LSTM / Ensemble
import os, json, joblib, traceback
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, ConfusionMatrixDisplay

# Ensure inline plotting when running in a Jupyter notebook (safe check).
# This is best-effort setup: the cell must also work when it is not executed
# by IPython, so any failure of the magic is deliberately ignored.
try:
    # `get_ipython` is not imported in this cell; presumably it is injected by
    # IPython/Jupyter and raises NameError under a plain interpreter.
    get_ipython().run_line_magic("matplotlib", "inline")
except Exception:
    # Nothing to recover: plotting falls back to matplotlib's default backend.
    pass

# ----------------- OUT_DIR discovery -----------------
def find_project_root(max_up=6):
    """Return the nearest directory that looks like the project root.

    Starting at the current working directory and walking towards the
    filesystem root, the first directory containing a ``backend`` or ``.git``
    entry is returned as a resolved string path. At most ``max_up`` ancestors
    above the working directory are inspected; when none matches, the resolved
    working directory itself is returned.
    """
    start = Path.cwd()
    # The working directory followed by its ancestors, capped at max_up + 1
    # entries in total (a negative cap inspects nothing).
    lineage = [start, *start.parents][: max(max_up + 1, 0)]
    for candidate in lineage:
        if any((candidate / marker).exists() for marker in ("backend", ".git")):
            return str(candidate.resolve())
    return str(start.resolve())

# PROJECT_ROOT (when set and non-empty) overrides marker-based discovery.
_root_override = os.environ.get("PROJECT_ROOT")
ROOT = _root_override if _root_override else find_project_root()

# An OUT_DIR environment variable wins; otherwise use <ROOT>/data/processed.
_default_out_dir = os.path.join(ROOT, "data", "processed")
OUT_DIR = os.environ.get("OUT_DIR", _default_out_dir)
os.makedirs(OUT_DIR, exist_ok=True)

print(f"ROOT: {ROOT}")
print(f"OUT_DIR: {OUT_DIR}")

# ----------------- helpers -----------------
def load_npy(varname, fname):
    """Return the array named ``varname`` from globals, else load ``fname`` from OUT_DIR.

    A global that is absent *or* set to ``None`` counts as missing (the same
    rule ``safe_joblib_load`` applies), so a stale ``None`` placeholder falls
    back to the file on disk instead of being handed to the caller.

    Raises:
        FileNotFoundError: no usable global exists and ``fname`` is not in OUT_DIR.
    """
    cached = globals().get(varname)
    if cached is not None:
        return cached
    p = os.path.join(OUT_DIR, fname)
    if not os.path.exists(p):
        raise FileNotFoundError(f"{fname} missing at {p}")
    # NOTE: allow_pickle=True can execute pickled payloads; only point this at
    # .npy files produced by this project.
    return np.load(p, allow_pickle=True)

def safe_joblib_load(varname, fname):
    """Fetch an artifact from globals when available, otherwise from OUT_DIR.

    Returns the global called ``varname`` if it exists and is not ``None``.
    Failing that, ``fname`` is loaded from OUT_DIR with joblib; ``None`` is
    returned when that file does not exist.
    """
    in_memory = globals().get(varname)
    if in_memory is not None:
        return in_memory
    artifact_path = os.path.join(OUT_DIR, fname)
    return joblib.load(artifact_path) if os.path.exists(artifact_path) else None

# ----------------- load data arrays -----------------
Xseq_test = load_npy("Xseq_test", "Xseq_test.npy")
yseq_test = load_npy("yseq_test", "yseq_test.npy")
print("Loaded Xseq_test", Xseq_test.shape, "yseq_test", yseq_test.shape)

# optional seq->session mapping (for session-level plotting)
seq_sid_test = globals().get("seq_sid_test", None)
ss_path = os.path.join(OUT_DIR, "session_split.json")
if seq_sid_test is None and os.path.exists(ss_path):
    try:
        # context manager so the handle is closed even when parsing fails
        with open(ss_path, "r", encoding="utf-8") as ss_fh:
            ss = json.load(ss_fh)
        seq_sid_test = np.asarray(ss.get("seq_sid_test") or ss.get("seq_sid_test_list") or [], dtype=int)
        print("Loaded seq_sid_test from session_split.json")
    except Exception as exc:
        # The mapping is optional; keep going without it, but report the reason.
        print("Could not load seq_sid_test from session_split.json:", exc)
        seq_sid_test = None

# ----------------- load RF artifacts -----------------
# safe_joblib_load already returns a non-None global of the same name before
# touching OUT_DIR, so it is called directly. The former
# `globals().get(...) or ...` form evaluated the estimator's truthiness, which
# for scikit-learn ensembles goes through __len__ and raises on an unfitted forest.
scaler_rf = safe_joblib_load("scaler_rf", "mouse_scaler.save")
rf_final = safe_joblib_load("rf_final", "mouse_rf.save")
if scaler_rf is None or rf_final is None:
    raise RuntimeError("RF model or scaler missing. Ensure mouse_rf.save & mouse_scaler.save exist in OUT_DIR or are present in globals().")

# ----------------- load LSTM model & scaler -----------------
# TensorFlow is only required when the model has to be read from disk, so an
# import failure is remembered here and reported at the point it matters.
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model as tf_load_model
    tf_import_error = None
except Exception as exc:
    tf = None
    tf_load_model = None
    tf_import_error = exc

# Prefer a model already held in the `model` global.
lstm_model = globals().get("model", None)
if lstm_model is None:
    m1 = os.path.join(OUT_DIR, "mouse_lstm.h5")
    m2 = os.path.join(OUT_DIR, "mouse_lstm.keras")
    # the .h5 file takes precedence over the .keras file
    model_path = m1 if os.path.exists(m1) else (m2 if os.path.exists(m2) else None)
    if model_path is None:
        raise RuntimeError("LSTM model missing. Ensure mouse_lstm.h5 or mouse_lstm.keras exists in OUT_DIR.")
    if tf_load_model is None:
        # The file exists; the real problem is the TensorFlow import, so say so
        # instead of claiming the model is missing.
        raise RuntimeError(
            f"LSTM model found at {model_path} but TensorFlow could not be imported."
        ) from tf_import_error
    lstm_model = tf_load_model(model_path)
    print("Loaded LSTM model from", model_path)

# safe_joblib_load checks globals for a non-None `lstm_scaler` before reading the file.
lstm_scaler = safe_joblib_load("lstm_scaler", "mouse_lstm_scaler.save")
if lstm_scaler is None:
    raise RuntimeError("LSTM scaler missing. Ensure mouse_lstm_scaler.save exists in OUT_DIR or is in globals().")

# ----------------- compute RF sequence probs -----------------
# The first three axes of Xseq_test are read as (sequence, window, feature).
n_seq, seq_len, fdim = (Xseq_test.shape[axis] for axis in range(3))

# The RF scores every window on its own; a sequence's score is the mean of the
# positive-class probabilities over its windows.
flat_windows = np.reshape(Xseq_test, (-1, fdim))
flat_windows_scaled = scaler_rf.transform(flat_windows)
rf_probs_flat = rf_final.predict_proba(flat_windows_scaled)[:, 1]
rf_probs_seq = np.reshape(rf_probs_flat, (n_seq, seq_len))
rf_seq_prob = np.mean(rf_probs_seq, axis=1)

# ----------------- compute LSTM sequence probs -----------------
# Scale with the LSTM scaler on flattened rows, then restore the sequence shape.
flat_seq = np.reshape(Xseq_test, (-1, fdim))
flat_seq_s = lstm_scaler.transform(flat_seq)
Xseq_test_s = np.reshape(flat_seq_s, Xseq_test.shape).astype(np.float32)

# predict (show progress bar)
_lstm_batch = int(os.environ.get("LSTM_BATCH", 64))
_lstm_raw = lstm_model.predict(Xseq_test_s, batch_size=_lstm_batch, verbose=1)
lstm_seq_prob = _lstm_raw.ravel()

# ----------------- ensemble -----------------
# Weighted average of the RF and LSTM sequence probabilities, normalised by the
# weight sum. Weights and threshold come from the environment (default 0.5).
w_rf = float(os.environ.get("ENS_WEIGHT_RF", 0.5))
w_lstm = float(os.environ.get("ENS_WEIGHT_LSTM", 0.5))
ens_weight_sum = w_rf + w_lstm
if ens_weight_sum == 0:
    # Dividing by zero would silently produce NaN/inf probabilities and break
    # every plot and metric below; fail loudly instead.
    raise ValueError("ENS_WEIGHT_RF + ENS_WEIGHT_LSTM must be non-zero to normalise the ensemble average.")
ensemble_seq_prob = (w_rf * rf_seq_prob + w_lstm * lstm_seq_prob) / ens_weight_sum
threshold = float(os.environ.get("ENS_THRESHOLD", 0.5))

# Hard 0/1 decisions at the shared threshold.
rf_pred = (rf_seq_prob >= threshold).astype(int)
lstm_pred = (lstm_seq_prob >= threshold).astype(int)
ens_pred = (ensemble_seq_prob >= threshold).astype(int)

# ----------------- plotting helpers -----------------
def plot_and_save(fig, fname):
    """Write ``fig`` to ``OUT_DIR/fname`` with a tight bounding box, then show it.

    The destination path is printed once the file has been written.
    """
    destination = os.path.join(OUT_DIR, fname)
    fig.savefig(destination, bbox_inches="tight")
    print(f"Saved figure -> {destination}")
    plt.show()

# 1) ROC plots (one figure each)
from sklearn.metrics import roc_curve

def single_roc_plot(y_true, y_score, title):
    """Draw a single ROC curve on a new 6x5 figure.

    Args:
        y_true: ground-truth labels passed to ``roc_curve``.
        y_score: scores passed to ``roc_curve``.
        title: prefix for the figure title; the AUC is appended.

    Returns:
        ``(figure, auc)`` — the new figure and the area under the plotted curve.
    """
    false_pos, true_pos, _thresholds = roc_curve(y_true, y_score)
    area = auc(false_pos, true_pos)

    fig = plt.figure(figsize=(6, 5))
    axes = fig.gca()
    axes.plot(false_pos, true_pos)
    axes.plot([0, 1], [0, 1], linestyle="--")  # dashed diagonal reference line
    axes.set_xlabel("False Positive Rate")
    axes.set_ylabel("True Positive Rate")
    axes.set_title(f"{title} ROC (AUC={area:.4f})")
    return fig, area

# One sequence-level ROC figure per predictor: (scores, title prefix, output file).
_roc_jobs = (
    (rf_seq_prob, "RF (sequence-level)", "roc_rf_seq.png"),
    (lstm_seq_prob, "LSTM (sequence-level)", "roc_lstm_seq.png"),
    (ensemble_seq_prob, "Ensemble (sequence-level)", "roc_ensemble_seq.png"),
)
_roc_aucs = []
for _scores, _roc_title, _roc_file in _roc_jobs:
    fig, _roc_auc = single_roc_plot(yseq_test, _scores, _roc_title)
    plot_and_save(fig, _roc_file)
    _roc_aucs.append(_roc_auc)

# AUCs are reused by the metrics summary further down.
a_rf, a_lstm, a_ens = _roc_aucs

# 2) Confusion matrices (one figure each)
def confmat_plot(y_true, y_pred, title):
    """Render the confusion matrix of ``y_pred`` against ``y_true`` on a new 5x4 figure.

    Returns:
        ``(figure, matrix)`` — the new figure and the array produced by
        ``confusion_matrix``.
    """
    matrix = confusion_matrix(y_true, y_pred)
    fig = plt.figure(figsize=(5, 4))
    axes = fig.gca()
    # cmap=None leaves the colour map to matplotlib's default; no colour bar is drawn
    ConfusionMatrixDisplay(matrix).plot(ax=axes, cmap=None, colorbar=False)
    axes.set_title(title)
    return fig, matrix

# One confusion-matrix figure per predictor: (hard predictions, title, output file).
_cm_jobs = (
    (rf_pred, "RF confusion matrix (sequence)", "confmat_rf_seq.png"),
    (lstm_pred, "LSTM confusion matrix (sequence)", "confmat_lstm_seq.png"),
    (ens_pred, "Ensemble confusion matrix (sequence)", "confmat_ensemble_seq.png"),
)
_cm_results = []
for _cm_pred, _cm_title, _cm_file in _cm_jobs:
    fig, _cm = confmat_plot(yseq_test, _cm_pred, _cm_title)
    plot_and_save(fig, _cm_file)
    _cm_results.append(_cm)
cm_rf, cm_l, cm_e = _cm_results

# 3) Print classification reports + AUCs
print("\n=== Metrics summary (sequence-level) ===\n")
_summary_rows = (
    ("RF AUC", a_rf, rf_pred),
    ("\nLSTM AUC", a_lstm, lstm_pred),
    ("\nEnsemble AUC", a_ens, ens_pred),
)
for _auc_label, _auc_value, _report_pred in _summary_rows:
    print(f"{_auc_label}: {_auc_value:.6f}")
    print(classification_report(yseq_test, _report_pred, digits=4))

# 4) Session-level ROC (optional)
# Only possible when every sequence has a session id aligned with the predictions.
_have_session_ids = seq_sid_test is not None and len(seq_sid_test) == len(ensemble_seq_prob)
if _have_session_ids:
    sess_ids = np.unique(seq_sid_test)
    sess_true, sess_rf, sess_l, sess_e = [], [], [], []
    for sid in sess_ids:
        mask = (seq_sid_test == sid)
        if not mask.any():
            continue
        # session label = label of the session's first sequence;
        # session score = mean of its sequence probabilities
        sess_true.append(int(yseq_test[mask][0]))
        sess_rf.append(float(rf_seq_prob[mask].mean()))
        sess_l.append(float(lstm_seq_prob[mask].mean()))
        sess_e.append(float(ensemble_seq_prob[mask].mean()))
    sess_true = np.array(sess_true)
    sess_e = np.array(sess_e)

    # Only the ensemble scores are plotted at session level.
    fpr_s, tpr_s, _ = roc_curve(sess_true, sess_e)
    a_sess = auc(fpr_s, tpr_s)
    fig = plt.figure(figsize=(6, 5))
    _sess_axes = fig.gca()
    _sess_axes.plot(fpr_s, tpr_s)
    _sess_axes.plot([0, 1], [0, 1], linestyle="--")
    _sess_axes.set_xlabel("False Positive Rate")
    _sess_axes.set_ylabel("True Positive Rate")
    _sess_axes.set_title(f"Ensemble ROC (session-level) AUC={a_sess:.4f}")
    savepath = os.path.join(OUT_DIR, "roc_ensemble_session.png")
    fig.savefig(savepath, bbox_inches="tight")
    print(f"Saved session-level ROC -> {savepath}")
    plt.show()
else:
    print("\nSkipping session-level ROC plot: seq_sid_test missing or length mismatch.")

print("\nCell G complete. Figures displayed inline and saved to OUT_DIR.")