In [1]:
import os

# --- Step 1: Define folder paths ---
trials_folder = "first trials"   # path to your trials folder
repeats_folder = "repeats" # path to your repeats folder

# --- Step 2: Function to extract participant IDs ---
def get_participant_ids(folder):
    ids = set()
    for fname in os.listdir(folder):
        if fname.endswith(".csv"):
            pid = fname.split("_")[0]   # participant id is before "_"
            ids.add(pid)
    return ids

# --- Step 3: Collect IDs from both folders ---
trials_ids = get_participant_ids(trials_folder)
repeats_ids = get_participant_ids(repeats_folder)

all_ids = trials_ids.union(repeats_ids)

# --- Step 4: Results ---
print(f"Unique participants in trials: {len(trials_ids)}")
print(f"Unique participants in repeats: {len(repeats_ids)}")
print(f"Total unique participants across both: {len(all_ids)}")


Unique participants in trials: 51
Unique participants in repeats: 21
Total unique participants across both: 52


In [2]:
# Find overlap
overlap_ids = trials_ids.intersection(repeats_ids)

print(f"Participants in both folders: {len(overlap_ids)}")
print("Overlapping IDs:", overlap_ids)


Participants in both folders: 20
Overlapping IDs: {'071', '057', '054', '056', '047', '032', '006', '049', '042', '062', '036', '058', '041', '009', '029', '016', '011', '065', '021', '048'}


In [4]:
import os

# Path to trials folder
trials_path = "first trials"

# Get unique participant IDs from trials folder
trials_ids = set([fname.split("_")[0] for fname in os.listdir(trials_path) if fname.endswith(".csv")])

print(f"Total unique participants in trials: {len(trials_ids)}")
print("Participant IDs:", trials_ids)


Total unique participants in trials: 51
Participant IDs: {'038', '071', '033', '037', '001', '057', '063', '054', '017', '069', '019', '003', '020', '056', '047', '013', '032', '006', '066', '049', '050', '027', '043', '053', '042', '062', '008', '036', '025', '028', '058', '051', '041', '026', '060', '015', '031', '064', '009', '029', '016', '034', '040', '070', '011', '052', '065', '039', '021', '048', '024'}


In [6]:
import os
import re
from typing import Set, Tuple, List

# ---------------------------
# CONFIG: set your folder paths
# ---------------------------
TRIALS_DIR = "first trials"
REPEATS_DIR = "repeats"

# ---------------------------
# Helper functions
# ---------------------------
PID_TASK_PATTERN = re.compile(r"^(\d+)_([0-3])\.csv$", re.IGNORECASE)

def extract_participant_ids(folder: str) -> Set[str]:
    """
    Extract participant IDs from filenames in a folder.
    Filenames expected as '<participantid>_<taskid>.csv', e.g. '029_2.csv'.
    Returns a set of participant IDs as zero-padded strings (as seen in filenames).
    """
    pids = set()
    if not os.path.isdir(folder):
        print(f"[WARN] Folder not found: {folder}")
        return pids

    for fname in os.listdir(folder):
        if not fname.lower().endswith(".csv"):
            continue
        m = PID_TASK_PATTERN.match(fname)
        if m:
            pid = m.group(1)  # keep as string to preserve leading zeros
            pids.add(pid)
        else:
            # If pattern doesn't match, try a more lenient parse
            # (handles filenames like '029_2_extra.csv' if ever present)
            base = os.path.splitext(fname)[0]
            parts = base.split("_")
            if len(parts) >= 2 and parts[0].isdigit():
                pids.add(parts[0])
            else:
                print(f"[WARN] Skipped (unexpected name): {fname}")
    return pids

def to_int_set(pids: Set[str]) -> Set[int]:
    """Convert string IDs (possibly zero-padded) to integers safely."""
    out = set()
    for pid in pids:
        try:
            out.add(int(pid))
        except ValueError:
            pass
    return out

def pretty_ids(pids: Set[str]) -> List[str]:
    """Return a sorted list of IDs (as strings) with natural numeric ordering."""
    return sorted(pids, key=lambda s: (len(s), int(s)))

def missing_ids(all_pid_ints: Set[int]) -> Tuple[int, List[int]]:
    """Return (max_id, sorted_missing_ids) based on observed integer IDs."""
    if not all_pid_ints:
        return 0, []
    max_id = max(all_pid_ints)
    expected = set(range(1, max_id + 1))
    missing = sorted(expected - all_pid_ints)
    return max_id, missing

def print_header(title: str):
    print("\n" + "=" * 72)
    print(title)
    print("=" * 72)

def print_list(title: str, items: List[str], max_show: int = 40):
    print(f"{title}: {len(items)}")
    if not items:
        return
    if len(items) <= max_show:
        print("  " + ", ".join(items))
    else:
        head = ", ".join(items[:max_show//2])
        tail = ", ".join(items[-max_show//2:])
        print(f"  {head}, ... , {tail}")

# ---------------------------
# Main analysis
# ---------------------------
if __name__ == "__main__":
    # 1) Collect participant IDs from each folder
    trials_pids_str  = extract_participant_ids(TRIALS_DIR)
    repeats_pids_str = extract_participant_ids(REPEATS_DIR)

    # 2) Convert to numeric sets for stats (while keeping string form for display)
    trials_pids_int  = to_int_set(trials_pids_str)
    repeats_pids_int = to_int_set(repeats_pids_str)

    # 3) Unions & intersections
    union_pids_str = trials_pids_str | repeats_pids_str
    inter_pids_str = trials_pids_str & repeats_pids_str
    only_trials_str  = trials_pids_str - repeats_pids_str
    only_repeats_str = repeats_pids_str - trials_pids_str

    # 4) Missing ID analysis (using integer IDs)
    all_pid_ints = trials_pids_int | repeats_pids_int
    max_id, missing = missing_ids(all_pid_ints)

    # 5) Pretty sorted lists for display
    trials_sorted   = pretty_ids(trials_pids_str)
    repeats_sorted  = pretty_ids(repeats_pids_str)
    union_sorted    = pretty_ids(union_pids_str)
    inter_sorted    = pretty_ids(inter_pids_str)
    only_trials_sorted  = pretty_ids(only_trials_str)
    only_repeats_sorted = pretty_ids(only_repeats_str)

    # ---------------------------
    # PRINT REPORT
    # ---------------------------
    print_header("Participant ID Summary (by folder)")
    print_list("Unique participants in TRIALS", trials_sorted)
    print_list("Unique participants in REPEATS", repeats_sorted)

    print_header("Combined Participants")
    print_list("Union (TRIALS ∪ REPEATS)", union_sorted)
    print_list("Overlap (TRIALS ∩ REPEATS)", inter_sorted)
    print_list("Only in TRIALS", only_trials_sorted)
    print_list("Only in REPEATS", only_repeats_sorted)

    print_header("Numeric Consistency Check")
    print(f"Max observed participant ID: {max_id if max_id else 'N/A'}")
    print(f"Total unique participants across both: {len(union_pids_str)}")
    print(f"Count of missing IDs between 1 and {max_id}: {len(missing)}")
    if missing:
        # Show first and last few for readability
        if len(missing) <= 40:
            print("Missing IDs:", ", ".join(map(str, missing)))
        else:
            head = ", ".join(map(str, missing[:20]))
            tail = ", ".join(map(str, missing[-20:]))
            print(f"Missing IDs: {head}, ... , {tail}")

    print("\nTip: A high max ID (e.g., 072) with fewer unique participants (e.g., 52) indicates non-contiguous ID assignment (some IDs unused or absent in the released data).")



Participant ID Summary (by folder)
Unique participants in TRIALS: 51
  001, 003, 006, 008, 009, 011, 013, 015, 016, 017, 019, 020, 021, 024, 025, 026, 027, 028, 029, 031, ... , 047, 048, 049, 050, 051, 052, 053, 054, 056, 057, 058, 060, 062, 063, 064, 065, 066, 069, 070, 071
Unique participants in REPEATS: 21
  006, 009, 011, 016, 021, 029, 032, 036, 041, 042, 047, 048, 049, 054, 056, 057, 058, 062, 065, 071, 072

Combined Participants
Union (TRIALS ∪ REPEATS): 52
  001, 003, 006, 008, 009, 011, 013, 015, 016, 017, 019, 020, 021, 024, 025, 026, 027, 028, 029, 031, ... , 048, 049, 050, 051, 052, 053, 054, 056, 057, 058, 060, 062, 063, 064, 065, 066, 069, 070, 071, 072
Overlap (TRIALS ∩ REPEATS): 20
  006, 009, 011, 016, 021, 029, 032, 036, 041, 042, 047, 048, 049, 054, 056, 057, 058, 062, 065, 071
Only in TRIALS: 31
  001, 003, 008, 013, 015, 017, 019, 020, 024, 025, 026, 027, 028, 031, 033, 034, 037, 038, 039, 040, 043, 050, 051, 052, 053, 060, 063, 064, 066, 069, 070
Only in REPEATS:

In [None]:
%pip install pandas numpy scikit-learn shap matplotlib seaborn joblib tqdm



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [9]:
import types
from shap_explain import run_shap_pipeline

args = types.SimpleNamespace(
    input_csv="phase2_complete_features.csv",
    label_col=None,            # let it auto-derive from file_path or condition
    group_col=None,            # auto-detect participant_id if exists
    drop_cols=None,
    model="rf",                # or "svm", "lr"
    test_size=0.2,
    seed=42,
    output_dir="shap_rf_outputs",
    kernel_bg_samples=200,
    kernel_test_samples=200,
    kernel_nsamples="auto",
    top_dependence=8
)

run_shap_pipeline(args)


ValueError: Could not infer label column. Please pass --label-col.

In [11]:
%pip install -U pip
%pip install pandas numpy scikit-learn shap matplotlib seaborn tqdm


Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 25.1
    Uninstalling pip-25.1:
      Successfully uninstalled pip-25.1
[0mSuccessfully installed pip-25.2
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [12]:
pip install xgboost lightgbm


Collecting xgboost
  Downloading xgboost-3.0.4-py3-none-macosx_10_15_x86_64.whl.metadata (2.1 kB)
Collecting lightgbm
  Downloading lightgbm-4.6.0-py3-none-macosx_10_15_x86_64.whl.metadata (17 kB)
Downloading xgboost-3.0.4-py3-none-macosx_10_15_x86_64.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m4.8 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hDownloading lightgbm-4.6.0-py3-none-macosx_10_15_x86_64.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m5.0 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: xgboost, lightgbm
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [lightgbm]1/2[0m [lightgbm]
[1A[2KSuccessfully installed lightgbm-4.6.0 xgboost-3.0.4
Note: you may need to restart the kernel to use updated packages.


In [13]:
import pandas as pd

df = pd.read_csv("phase2_complete_features.csv")
print(df.shape)
print(df.columns.tolist())
print(df.head(3).T)  # vertical view


(734, 17)
['duration_samples', 'duration_time', 'mean_velocity', 'std_velocity', 'mean_acceleration', 'mean_jerk', 'blink_count', 'blink_rate', 'participant_id', 'trial_number', 'condition', 'trial_type', 'file_path', 'gaze_entropy', 'fixation_count', 'mean_fixation_duration_ms', 'mean_fixation_dispersion']
                                                0                       1  \
duration_samples                             7450                    7450   
duration_time                           46.543963               46.543963   
mean_velocity                            2.118361                2.118361   
std_velocity                             1.529666                1.529666   
mean_acceleration                        0.000438                0.000438   
mean_jerk                                0.000066                0.000066   
blink_count                                     0                       0   
blink_rate                                    0.0                     0.0  

In [17]:
import os
import pandas as pd
from shap_explain_fixed import run_shap_pipeline_fixed

# 1) If needed: create label from file_path (_0/_1 -> 0, _2/_3 -> 1)
df = pd.read_csv("phase2_complete_features.csv")

def derive_label_from_path(path):
    import re, os
    basename = os.path.basename(str(path)).strip()
    m = re.search(r'_(\d)\.csv$', basename)
    if m:
        tid = int(m.group(1))
        return 1 if tid in (2,3) else 0
    return None

if 'label' not in df.columns:
    df['label'] = df['file_path'].apply(derive_label_from_path)
    df = df.dropna(subset=['label'])
    df['label'] = df['label'].astype(int)
    df.to_csv("phase2_complete_features_with_label.csv", index=False)
    print("Saved: phase2_complete_features_with_label.csv")
else:
    print("Label column already present.")

# 2) Run SHAP pipeline (RF, subject-wise split via participant_id)
run_shap_pipeline_fixed(
    input_csv="phase2_complete_features_with_label.csv",
    label_col="label",
    group_col="participant_id",               # prevents identity leakage
    drop_cols=["file_path", "trial_type", "condition"],  # drop meta columns
    model="rf",                               # use RF for SHAP TreeExplainer
    test_size=0.2,
    seed=42,
    output_dir="shap_rf_outputs"
)

Saved: phase2_complete_features_with_label.csv
              precision    recall  f1-score   support

           0     1.0000    1.0000    1.0000        99
           1     1.0000    1.0000    1.0000       102

    accuracy                         1.0000       201
   macro avg     1.0000    1.0000    1.0000       201
weighted avg     1.0000    1.0000    1.0000       201



ValueError: This RandomForestClassifier estimator requires y to be passed, but the target y is None.

In [19]:
def plot_saliency_overlay(time_s, saliency, velocity, blink_flag, out_path):
    import matplotlib.pyplot as plt
    fig, ax1 = plt.subplots(figsize=(11,3.5))

    ax1.plot(time_s, saliency, color='crimson', lw=2, label='Grad-CAM++')
    ax1.set_ylabel('Saliency', color='crimson')
    ax1.set_xlabel('Time (sec)')
    ax1.tick_params(axis='y', labelcolor='crimson')

    ax2 = ax1.twinx()
    ax2.plot(time_s, velocity, color='navy', alpha=0.6, label='Velocity')
    ax2.fill_between(time_s, 0, blink_flag, color='gold', alpha=0.3, label='Blink')
    ax2.set_ylabel('Velocity / Blink')

    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1+lines2, labels1+labels2, loc='upper right')

    plt.title("Temporal saliency aligned with dynamic signals")
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()


In [20]:
def plot_classwise_mean_saliency(saliency_class0, saliency_class1, out_path):
    import matplotlib.pyplot as plt
    T = len(saliency_class0)
    t = np.arange(T) / 60.0  # seconds
    plt.figure(figsize=(10,3))
    plt.plot(t, saliency_class0, label='Low Load', color='teal', lw=2)
    plt.plot(t, saliency_class1, label='High Load', color='crimson', lw=2)
    plt.xlabel("Time (sec)")
    plt.ylabel("Mean saliency")
    plt.title("Per-class average temporal saliency (Grad-CAM++)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()


In [23]:
from graphviz import Digraph

def esc(text: str) -> str:
    """Escape special characters for Graphviz HTML-like labels."""
    return (text
            .replace('&', '&amp;')
            .replace('<', '&lt;')
            .replace('>', '&gt;'))

def _box(label_title: str, label_sub: str) -> str:
    """
    Build an HTML-like label with bold title and smaller subtitle on the next line.
    Graphviz renders this crisply in PDF/SVG.
    """
    return f'''<
      <B>{esc(label_title)}</B><BR ALIGN="CENTER"/>
      <FONT POINT-SIZE="14">{esc(label_sub)}</FONT>
    >'''


def build_pipeline_report(
    filename="fig_feature_pipeline_proper_large",
    pdf_size="10,6!",      # target size in inches (width,height) for PDF/SVG
    png_dpi="600"          # DPI for raster export
):
    # Use DOT engine and start with vector (PDF/SVG)
    g = Digraph(engine="dot", filename=filename, format="pdf")

    # -----------------------------
    # Global graph attributes
    # -----------------------------
    g.attr(
        "graph",
        rankdir="TB",           # overall top-to-bottom direction (tiled into two rows)
        bgcolor="white",
        margin="0.2",
        pad="0.2",
        nodesep="0.65",         # spacing between nodes in same rank
        ranksep="0.8",          # spacing between ranks (rows)
        splines="ortho",        # orthogonal edges
        concentrate="true",     # merge parallel edges where possible
        outputorder="edgesfirst",
        size=pdf_size,          # constrain overall figure size (for PDF/SVG)
        ratio="compress"        # compress whitespace
    )

    g.attr(
        "node",
        shape="box",
        style="rounded,filled",
        color="#444444",
        fillcolor="#F7F9FC",    # subtle fill
        fontname="Helvetica",
        fontsize="18",          # good for print while not overwhelming
        margin="0.20,0.12",     # inner padding within node
        penwidth="1.6"
    )

    g.attr(
        "edge",
        color="#555555",
        penwidth="1.6",
        arrowsize="0.9"
    )

    # -----------------------------
    # Nodes (two-row layout)
    # -----------------------------
    # Top row
    g.node("raw",     label=_box("Raw Eye-Tracking Data", "Fixations · Saccades · Blinks"))
    g.node("preproc", label=_box("Preprocessing", "Interpolation · Smoothing · Blink Handling"))
    g.node("segment", label=_box("Temporal Segmentation", "2s Windows (Non-Overlapping)"))
    g.node("features",label=_box("Feature Extraction", "Velocity · Entropy · Blink Rate"))

    # Bottom row
    g.node("fusion",  label=_box("Feature Fusion &amp; Cleaning", "Merge · Normalize · Deduplicate"))
    g.node("csvs",    label=_box("Structured CSV Exports", "Trial-Level &amp; Epoch-Level"))
    g.node("models",  label=_box("Downstream Models", "Classical ML · Deep Learning"))

    # -----------------------------
    # Rank constraints (tile into two rows)
    # -----------------------------
    with g.subgraph() as top:
        top.attr(rank="same")
        top.node("raw")
        top.node("preproc")
        top.node("segment")
        top.node("features")

    with g.subgraph() as bottom:
        bottom.attr(rank="same")
        bottom.node("fusion")
        bottom.node("csvs")
        bottom.node("models")

    # -----------------------------
    # Primary flow edges
    # -----------------------------
    # Top row connections
    g.edge("raw", "preproc")
    g.edge("preproc", "segment")
    g.edge("segment", "features")

    # Vertical flow from top to bottom row
    g.edge("features", "fusion")

    # Bottom row connections
    g.edge("fusion", "csvs")
    g.edge("csvs", "models")

    # Invisible edges to keep columns aligned top/bottom
    g.edge("preproc", "csvs", style="invis")
    g.edge("segment", "fusion", style="invis")

    # -----------------------------
    # Render vector formats first (PDF + SVG)
    # -----------------------------
    g.render(cleanup=True)   # PDF
    g.format = "svg"
    g.render(cleanup=True)   # SVG (great for web and vector editing)

    # -----------------------------
    # Render high-DPI PNG for screens/slides
    # -----------------------------
    g.format = "png"
    # IMPORTANT: set PNG DPI right before rendering raster output
    g.attr("graph", dpi=str(png_dpi))
    g.render(cleanup=True)

    return f"Generated {filename}.pdf, {filename}.svg, and {filename}.png (at {png_dpi} DPI)"


if __name__ == "__main__":
    print(build_pipeline_report())


Generated fig_feature_pipeline_proper_large.pdf, fig_feature_pipeline_proper_large.svg, and fig_feature_pipeline_proper_large.png (at 600 DPI)
