In [None]:
import copy
import glob
import os
import shutil
import sys
import time
from pathlib import Path

import numpy as np
import yaml
from IPython.display import display, Markdown, Image

# Repository root: the folder that contains train_pipeline.py (and configs/).
# Defaults to the notebook's working directory, with a single fallback to its
# parent; replace Path.cwd() with an explicit Path(...) if neither is right.
PROJECT_ROOT = Path.cwd()

if not (PROJECT_ROOT / "train_pipeline.py").exists():
    print("Couldn't find train_pipeline.py in current directory. Trying parent directory...")
    _parent_dir = PROJECT_ROOT.parent
    if not (_parent_dir / "train_pipeline.py").exists():
        print("Please set PROJECT_ROOT to the folder containing train_pipeline.py")
        raise FileNotFoundError("train_pipeline.py not found")
    PROJECT_ROOT = _parent_dir

# Prepend so `from train_pipeline import ...` resolves to this checkout first.
sys.path.insert(0, str(PROJECT_ROOT))

# Light sanity print: where we are and which model configs are available.
print("PROJECT_ROOT:", PROJECT_ROOT)
_available_configs = [cfg_file.name for cfg_file in (PROJECT_ROOT / "configs").glob("*.yaml")]
print("Available configs:", _available_configs)

In [None]:

# ==============================================================
# Experiment parameters — the only cell that should need editing
# ==============================================================

# Model: must match a YAML file in configs/ (e.g. "tcformer", "eegnet").
MODEL_NAME = "ratiowavenet"

# Dataset: "bcic2a", "bcic2b" or "hgd".
DATASET = "bcic2a"

# Evaluation protocol: True -> leave-one-subject-out (subject-independent);
# False -> intra-subject.
USE_LOSO = False

# Device index, forwarded as config["gpu_id"]. Per the original note, -1 is
# mapped by train_pipeline to CPU / auto — TODO confirm against the script.
GPU_ID = 0

# Inter-trial augmentation: True / False force it; None keeps the YAML default.
INTERAUG = True

# Seeds to run (None -> use the YAML default seed) and the intended pause,
# in seconds, between consecutive seed runs.
SEED_LIST = list(range(5))    # == [0, 1, 2, 3, 4]
SEED_SLEEP_S = 2

# Output toggles.
VERBOSE_TRAIN = True          # True -> Lightning progress bar; False -> final summary only
PLOT_CM_PER_SUBJECT = True    # one confusion matrix per subject
PLOT_CM_AVERAGE = True        # confusion matrix averaged over subjects

# Subject selection: "all" | "one" | "list".
SUBJECT_MODE = "one"
ONE_SUBJECT = 1               # read only when SUBJECT_MODE == "one"
SUBJECT_LIST = [1, 2, 3]      # read only when SUBJECT_MODE == "list"

# Quick overrides.
OVERRIDE_MAX_EPOCHS = None    # int (e.g. 50) replaces the computed max_epochs; None keeps it
SAVE_CHECKPOINTS = False      # True forces config["save_checkpoint"]; False leaves the YAML value alone


In [None]:

# This mirrors the logic inside run(), but driven by notebook variables.
from train_pipeline import train_and_test
from types import SimpleNamespace

CONFIG_DIR = PROJECT_ROOT / "configs"
config_path = CONFIG_DIR / f"{MODEL_NAME}.yaml"
assert config_path.exists(), f"Config not found: {config_path}"

# Always start from a fresh YAML load: every step below mutates `config`
# (e.g. "preprocessing" is replaced by its per-dataset sub-dict), so
# re-running this cell must not reuse a previously mutated dict.
with open(config_path) as f:
    config = yaml.safe_load(f)

# --- Apply LOSO / intra logic for dataset & epochs ---
if USE_LOSO:
    config["dataset_name"] = f"{DATASET}_loso"
    config["max_epochs"]   = config["max_epochs_loso_hgd"] if DATASET == "hgd" else config["max_epochs_loso"]
    # LOSO-specific warmup replaces the regular warmup, if the YAML defines one.
    if "model_kwargs" in config and "warmup_epochs_loso" in config["model_kwargs"]:
        config["model_kwargs"]["warmup_epochs"] = config["model_kwargs"]["warmup_epochs_loso"]
else:
    config["dataset_name"] = DATASET
    config["max_epochs"]   = config["max_epochs_2b"] if DATASET == "bcic2b" else config["max_epochs"]

# --- Preprocessing per-dataset ---
# The YAML keeps one preprocessing section per dataset; keep only ours.
preprocessing_by_dataset = config["preprocessing"]
if DATASET not in preprocessing_by_dataset:
    raise KeyError(
        f"No preprocessing section for dataset {DATASET!r} in {config_path.name}; "
        f"available: {sorted(preprocessing_by_dataset)}"
    )
config["preprocessing"] = preprocessing_by_dataset[DATASET]
# A top-level "z_scale" in the YAML overrides the per-dataset value.
config["preprocessing"]["z_scale"] = config.get("z_scale", config["preprocessing"].get("z_scale", False))

# --- Inter-trial augmentation override ---
if isinstance(INTERAUG, bool):
    # Explicit True/False from the notebook wins.
    config["preprocessing"]["interaug"] = INTERAUG
else:
    # None (or anything else): top-level "interaug", then the per-dataset value, then False.
    config["preprocessing"]["interaug"] = config.get("interaug", config["preprocessing"].get("interaug", False))
# The top-level key has been folded into preprocessing; drop it for cleanliness.
config.pop("interaug", None)

# --- Subject selection ---
if SUBJECT_MODE == "all":
    config["subject_ids"] = "all"
elif SUBJECT_MODE == "one":
    config["subject_ids"] = ONE_SUBJECT
elif SUBJECT_MODE == "list":
    config["subject_ids"] = SUBJECT_LIST
else:
    raise ValueError("SUBJECT_MODE must be one of: 'all', 'one', 'list'")

# --- Seed & GPU ---
config["gpu_id"] = GPU_ID
# The parameter cell defines SEED_LIST, not a scalar SEED, so the previous
# `if SEED is not None` raised NameError on a fresh kernel. `config` carries a
# single seed: use the first entry; None / [] keeps the YAML default.
# NOTE(review): assumes train_and_test expects a scalar config["seed"] — confirm.
if SEED_LIST:
    config["seed"] = SEED_LIST[0]

# --- Optional epoch override ---
if OVERRIDE_MAX_EPOCHS is not None:
    config["max_epochs"] = int(OVERRIDE_MAX_EPOCHS)

# --- Plotting toggles ---
config["plot_cm_per_subject"] = bool(PLOT_CM_PER_SUBJECT)
config["plot_cm_average"]     = bool(PLOT_CM_AVERAGE)

# --- Optional checkpoint saving ---
# Only forces saving on; SAVE_CHECKPOINTS=False leaves the YAML value untouched.
if SAVE_CHECKPOINTS:
    config["save_checkpoint"] = True

# --- Verbosity control (Lightning progress bar / logs) ---
# train_pipeline sets logger=False; the progress bar is toggled through this
# environment variable, which is read when a Trainer is constructed
# (see the Trainer patch cell).
os.environ["PL_TRAIN_PROGRESS_BAR"] = "1" if VERBOSE_TRAIN else "0"

display(Markdown("### Config ready"))
print(yaml.dump(config, sort_keys=False))


In [None]:

# We dynamically patch the Trainer kwargs to respect our VERBOSE_TRAIN flag.
# This avoids editing your file on disk. If you prefer, you can permanently add:
#   enable_progress_bar=bool(os.environ.get("PL_TRAIN_PROGRESS_BAR","1")=="1")
#
# Below, pytorch_lightning.Trainer.__init__ is monkey-patched once per session (a marker attribute on the class makes re-running this cell a no-op).
import pytorch_lightning as pl
from functools import wraps
from pytorch_lightning.trainer.trainer import Trainer as _PLTrainer

if hasattr(_PLTrainer, "_patched_progress_bar"):
    # Marker already set in this kernel: __init__ is wrapped, and wrapping it
    # again would stack a second wrapper on top.
    print("Trainer already patched in this session.")
else:
    _trainer_init_unpatched = _PLTrainer.__init__

    @wraps(_trainer_init_unpatched)
    def _trainer_init_with_notebook_defaults(self, *args, **kwargs):
        # Fill in notebook-friendly defaults; any value the caller passes
        # explicitly is left as-is.
        kwargs.setdefault(
            "enable_progress_bar",
            os.environ.get("PL_TRAIN_PROGRESS_BAR", "1") == "1",
        )
        kwargs.setdefault("enable_model_summary", False)  # less noise in notebooks
        kwargs.setdefault("log_every_n_steps", 1)
        return _trainer_init_unpatched(self, *args, **kwargs)

    _PLTrainer.__init__ = _trainer_init_with_notebook_defaults
    _PLTrainer._patched_progress_bar = True
    print("Patched PyTorch Lightning Trainer to honor VERBOSE_TRAIN toggle.")


In [None]:
# --- RUN ---
# SEED_LIST is the "list of seeds to run", but previously train_and_test ran
# exactly once and SEED_LIST / SEED_SLEEP_S were never used. Run once per seed,
# each time on a deep copy of `config` so nothing train_and_test mutates can
# leak into the next run. SEED_LIST = None / [] -> a single run with `config`
# exactly as built (YAML seed).
seeds_to_run = list(SEED_LIST) if SEED_LIST else [None]

start = time.time()
print("Starting training & testing...")
for run_idx, seed in enumerate(seeds_to_run):
    run_config = copy.deepcopy(config)
    if seed is not None:
        run_config["seed"] = seed
        print(f"[{run_idx + 1}/{len(seeds_to_run)}] seed={seed}")
    train_and_test(run_config)
    # Pause between consecutive seeds only (not after the last one).
    if SEED_SLEEP_S and run_idx < len(seeds_to_run) - 1:
        time.sleep(SEED_SLEEP_S)
elapsed = time.time() - start
print(f"Done in {elapsed/60:.1f} min")

In [None]:

# Try to locate the most recent results folder produced by this run and preview assets.
from datetime import datetime

def _show_pngs(folder, limit=12):
    """Display up to `limit` PNGs from `folder` (sorted by filename), each under a bold caption.

    Prints "No '<folder name>' folder found." when `folder` does not exist.
    """
    if not folder.exists():
        print(f"No '{folder.name}' folder found.")
        return
    for img_path in sorted(folder.glob("*.png"))[:limit]:
        display(Markdown(f"**{img_path.name}**"))
        display(Image(filename=str(img_path)))


def _result_dirs_by_mtime(root, pattern):
    """Return directories under `root` matching glob `pattern`, oldest first by mtime."""
    # Filter to directories: a loose pattern can also match stray files,
    # and `latest / "confmats"` only makes sense for a directory.
    return sorted((p for p in root.glob(pattern) if p.is_dir()), key=os.path.getmtime)


results_root = PROJECT_ROOT / "results"
if not results_root.exists():
    print("No results folder found yet.")
else:
    # Result folders are named f"{model_name}_{dataset_name}_seed-...".
    # The YAML may not have a top-level "model" key (it may live under
    # model_kwargs), so fall back to MODEL_NAME instead of raising KeyError.
    model_tag = config.get("model", MODEL_NAME)
    dataset_tag = config["dataset_name"]

    candidates = _result_dirs_by_mtime(results_root, f"{model_tag}_{dataset_tag}_seed-*")
    if not candidates:
        # Looser fallback; note it can also match e.g. "<dataset>_loso" runs.
        candidates = _result_dirs_by_mtime(results_root, f"*{dataset_tag}*")

    if not candidates:
        print("No matching result directories found.")
    else:
        latest = candidates[-1]
        print("Latest result dir:", latest)

        # Confusion matrices, then training curves (first 12 of each).
        _show_pngs(latest / "confmats")
        _show_pngs(latest / "curves")

        # Summary text/CSV, if the pipeline's writer created one
        # (adjust the patterns if write_summary uses different filenames).
        summary_files = list(latest.glob("summary*.txt")) + list(latest.glob("summary*.csv"))
        for s in summary_files:
            display(Markdown(f"### {s.name}"))
            try:
                print(s.read_text())
            except Exception as e:
                print("Couldn't read summary file:", e)
