# Training MLP Model

This notebook trains and evaluates the **feature-based MLP** using the repo that we wrote:
- `src/utils/data_loader.py` (MOABB loader)
- `src/pipeline/feature_pipeline.py` (CSP + PSD + Asym + StandardScaler)
- `src/models/mlp.py` (classifier)
- `testing/train_feature_mlp_with_pipeline.py` (trainer)
- `testing/eval_feature_mlp.py` (validator)

It supports two flows:
1. **Notebook is inside the repo root** (recommended): it just runs.
2. **Clone the repo**: set `REPO_URL` below.

Repo: https://github.com/Alberta-Bionix-natHacks-2025/natHacks2025.git


In [1]:
# Global switches for Run all
RUN_SETUP_CELLS = False

In [2]:
#@title Install dependencies (Colab)
import sys, subprocess
def pipi(pkgs):
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q'] + pkgs)

# Core
pipi(['mne', 'moabb', 'braindecode', 'scikit-learn', 'joblib'])
try:
    import torch
except Exception:
    pipi(['torch'])
print('Dependencies installed.')


Dependencies installed.


In [4]:
#@title Clone / detect repo and choose a branch (handles dirty weights safely)
import os, sys, subprocess, shutil, time
from pathlib import Path

REPO_URL = "https://github.com/Alberta-Bionix-natHacks-2025/natHacks2025"  #@param {type:"string"}
REPO_URL = REPO_URL.strip()
REPO_NAME = REPO_URL.rstrip("/").split("/")[-1].removesuffix(".git") if REPO_URL else None

def _run(cmd):
    return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

def _ensure_ipywidgets():
    try:
        import ipywidgets  # noqa
        return True
    except Exception:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "ipywidgets"])
            import ipywidgets  # noqa
            return True
        except Exception as e:
            print("Could not install ipywidgets:", e)
            return False

def _git_root():
    r = _run(["git", "rev-parse", "--show-toplevel"])
    return Path(r.stdout.strip()) if r.returncode == 0 else None

def _safe_clone(url: str):
    if not url: return False
    if _git_root():
        print("Already inside a git repo; skipping clone.")
        return True
    repo_name = url.rstrip("/").split("/")[-1].removesuffix(".git")
    if Path(repo_name, ".git").exists():
        print(f"Using existing clone: {repo_name}")
        os.chdir(repo_name)
        return True
    print("Cloning", url)
    r = _run(["git", "clone", "--depth", "1", url])
    if r.returncode != 0:
        print("Clone failed:\n", r.stderr)
        return False
    os.chdir(repo_name)
    return True

def _current_branch():
    r = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    return r.stdout.strip() if r.returncode == 0 else "unknown"

def _dirty_files():
    r = _run(["git", "status", "--porcelain"])
    files = []
    if r.returncode == 0:
        for ln in r.stdout.splitlines():
            ln = ln.strip()
            if not ln: continue
            # format: "XY path"
            parts = ln.split(maxsplit=1)
            if len(parts) == 2:
                files.append(parts[1])
    return files

def _weights_dirty():
    return [f for f in _dirty_files() if f.startswith("data/weights/")]

def _stash_weights():
    # stash only data/weights (tracked + untracked)
    r = _run(["git", "stash", "push", "-u", "--", "data/weights"])
    return r.returncode == 0, r.stdout + r.stderr

# ---- locate repo or clone ----
git_root = _git_root()
if git_root:
    os.chdir(git_root)
    print(f"Detected repo root at: {git_root}")
elif REPO_URL:
    ok = _safe_clone(REPO_URL)
    if not ok:
        raise SystemExit("Clone failed. Fix REPO_URL and re-run this cell.")
else:
    if REPO_NAME and Path(REPO_NAME).exists():
        os.chdir(REPO_NAME)
        print(f"Using folder: {Path.cwd()}")
    else:
        raise SystemExit("Place this notebook in your repo root OR set REPO_URL to clone it.")

ROOT = Path(".").resolve()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
print("Repo path in sys.path ✅", ROOT)

# ---- branch dropdown & checkout with strategies ----
if Path(".git").exists():
    # Ensure origin remote
    if _run(["git", "remote", "get-url", "origin"]).returncode != 0 and REPO_URL:
        _run(["git", "remote", "add", "origin", REPO_URL])

    _run(["git", "fetch", "origin", "--prune"])
    ls = _run(["git", "ls-remote", "--heads", "origin"])
    if ls.returncode != 0:
        print("ls-remote failed:\n", ls.stderr)
    else:
        branches = sorted({
            ln.split("\trefs/heads/")[1]
            for ln in ls.stdout.splitlines() if "\trefs/heads/" in ln
        })
        default = "main" if "main" in branches else ("master" if "master" in branches else (branches[0] if branches else ""))

        if _ensure_ipywidgets():
            import ipywidgets as widgets
            from IPython.display import display, clear_output

            dd = widgets.Dropdown(options=branches, value=default, description="Branch:")
            strategy = widgets.RadioButtons(
                options=[
                    ("Abort if dirty (safe default)", "abort"),
                    ("Stash data/weights then checkout", "stash"),
                    ("Force checkout (discard ALL local changes)", "force"),
                ],
                value="abort",
                description="On local changes:",
                layout=widgets.Layout(width="auto")
            )
            btn = widgets.Button(description="Checkout", button_style="primary")
            popbtn = widgets.Button(description="Pop last stash", tooltip="git stash pop")
            out = widgets.Output()

            def show_status():
                with out:
                    clear_output()
                    print(f"Current branch: { _current_branch() }")
                    dirty = _dirty_files()
                    if dirty:
                        print("Dirty files:")
                        for f in dirty:
                            print("  -", f)
                    else:
                        print("Working tree clean.")

            def checkout(_):
                with out:
                    clear_output()
                    br = dd.value
                    strat = strategy.value
                    print(f"Selected branch: {br}")
                    print(f"Strategy: {strat}")

                    weights_dirty = _weights_dirty()
                    all_dirty = _dirty_files()

                    if strat == "abort" and all_dirty:
                        print("\n❌ Local changes present. Aborting checkout.\n")
                        if weights_dirty:
                            print("Dirty under data/weights/:")
                            for f in weights_dirty: print("  -", f)
                        else:
                            print("Dirty files (outside data/weights):")
                            for f in all_dirty: print("  -", f)
                        print("\nChoose a different strategy if you really want to switch.")
                        return

                    if strat == "stash" and weights_dirty:
                        ok, msg = _stash_weights()
                        print("Stash data/weights:", "OK" if ok else "FAILED")
                        if msg.strip(): print(msg.strip())
                    elif strat == "force":
                        print("⚠️ Force checkout will discard ALL local changes.")
                        # --force + -B to reset local branch to remote
                        _ = _run(["git", "reset", "--hard"])

                    print(f"\nFetching origin/{br} (shallow)…")
                    f = _run(["git", "fetch", "origin", f"+refs/heads/{br}:refs/remotes/origin/{br}", "--depth=1"])
                    if f.returncode != 0:
                        print("fetch failed:\n", f.stderr); return

                    print(f"Checking out {br} …")
                    co = _run(["git", "checkout", "-B", br, f"origin/{br}"])
                    if co.returncode != 0:
                        print("checkout failed:\n", co.stderr); return
                    print(co.stdout or "Checked out.")
                    print("HEAD:")
                    print(_run(["git", "--no-pager", "log", "--oneline", "-n", "1"]).stdout)

            def pop_stash(_):
                with out:
                    clear_output()
                    r = _run(["git", "stash", "pop"])
                    print(r.stdout or r.stderr or "No stash entries.")

            btn.on_click(checkout)
            popbtn.on_click(pop_stash)
            display(widgets.VBox([
                widgets.HBox([dd]),
                strategy,
                widgets.HBox([btn, popbtn]),
                out
            ]))
            show_status()
            print("Pick a branch, choose a strategy, then click Checkout.")
        else:
            print("ipywidgets not available; run: pip install ipywidgets")
else:
    print("Not a git repo; branch dropdown not available.")


Detected repo root at: /content/natHacks2025
Repo path in sys.path ✅ /content/natHacks2025


VBox(children=(HBox(children=(Dropdown(description='Branch:', index=8, options=('AccessingDataFromOpenBCI', 'F…

Pick a branch, choose a strategy, then click Checkout.


Downloaded the wrong repo?? **Wanna start fresh?**



In [20]:
#@title cd to /content  (OFF by default)
RUN_THIS = True  #@param {type:"boolean"}
if RUN_THIS:
    import os
    if os.getcwd() != "/content":
        os.chdir("/content")
        print("cd /content")
else:
    print("Skipped.")


cd /content


In [18]:
#@title Route MOABB/MNE downloads into the repo
import os
import mne, moabb
RAW_DIR = os.path.abspath('data/eeg/raw')
os.makedirs(RAW_DIR, exist_ok=True)
mne.set_config('MNE_DATA', RAW_DIR, set_env=True)
moabb.set_download_dir(RAW_DIR)
print('MNE_DATA =', mne.get_config('MNE_DATA'))
print('MOABB download dir set. ✅')


MNE_DATA = /content/natHacks2025/data/eeg/raw
MOABB download dir set. ✅


In [13]:
#@title Train with pipeline (saves weights + feature pipeline)
import subprocess, sys
print('Running trainer: testing/train_feature_with_pipeline_test_mlp.py')
subprocess.check_call([sys.executable, '-m', 'testing.train_feature_with_pipeline_test_mlp', "--fourch"])
print('\nTraining completed. Files in data/weights')


Running trainer: testing/train_feature_with_pipeline_test_mlp.py

Training completed. Files in data/weights


In [14]:
#@title Evaluate saved model
print('Running evaluator: testing/eval_test_mlp.py')
res = subprocess.run(
    [sys.executable, "-u", "-m", "testing.eval_test_mlp_4ch"],  # -u = unbuffered stdout
    text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
print(res.stdout)
print("Exit code:", res.returncode)


Running evaluator: testing/eval_test_mlp.py
[4ch] Loaded 288 trials | chans=['C4', 'FC2', 'FC1', 'C3'] | samples=500
Val accuracy: 0.8793103448275862
Confusion matrix:
 [[24  5]
 [ 2 27]]

Classification report:
               precision    recall  f1-score   support

        left       0.92      0.83      0.87        29
       right       0.84      0.93      0.89        29

    accuracy                           0.88        58
   macro avg       0.88      0.88      0.88        58
weighted avg       0.88      0.88      0.88        58


Exit code: 0


In [15]:
#@title (Optional) Download weights/pipeline
from pathlib import Path
weights = Path('data/weights/feature_mlp_pipeline_4ch.pth')
pipe = Path('data/weights/feature_pipeline_4ch.joblib')
print('Weights exist:', weights.exists(), weights)
print('Pipeline exists:', pipe.exists(), pipe)
try:
    from google.colab import files
    if weights.exists(): files.download(str(weights))
    if pipe.exists(): files.download(str(pipe))
except Exception as e:
    print('Not in Colab download context or download blocked:', e)


Weights exist: True data/weights/feature_mlp_pipeline_4ch.pth
Pipeline exists: True data/weights/feature_pipeline_4ch.joblib


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Notes
- To **push weights back to GitHub**, you can use a Personal Access Token and `git push` from Colab, or download and commit locally.
- For **real-time**, feed `(C, T)` buffers from your Ganglion/BrainFlow stream into `RealTimeClassifier.predict(...)`.
