# Training MLP Model

This notebook trains and evaluates the **feature-based MLP** using the repo that we wrote:
- `src/utils/data_loader.py` (MOABB loader)
- `src/pipeline/feature_pipeline.py` (CSP + PSD + Asym + StandardScaler)
- `src/models/mlp.py` (classifier)
- `testing/train_feature_with_pipeline_test_mlp.py` (trainer)
- `testing/eval_test_mlp.py` (validator)

It supports two flows:
1. **Notebook is inside the repo root** (recommended): it just runs.
2. **Clone the repo**: set `REPO_URL` below.

Repo: https://github.com/Alberta-Bionix-natHacks-2025/natHacks2025.git


In [15]:
# Global switches for Run all
# NOTE(review): RUN_SETUP_CELLS is not read by any cell visible in this notebook;
# the install and clone cells run unconditionally, and the "cd to /content" cell
# is gated by its own RUN_THIS flag. Confirm whether this switch is still needed.
RUN_SETUP_CELLS = False

In [16]:
#@title Install dependencies (Colab)
import sys, subprocess
def pipi(pkgs):
    """Quietly pip-install the list ``pkgs`` using the interpreter that runs this kernel.

    Raises ``subprocess.CalledProcessError`` when pip exits with a non-zero status.
    """
    pip_cmd = [sys.executable, '-m', 'pip', 'install', '-q']
    subprocess.check_call(pip_cmd + pkgs)

# Core packages are installed in a single pip invocation.
CORE_PACKAGES = ['mne', 'moabb', 'braindecode', 'scikit-learn', 'joblib']
pipi(CORE_PACKAGES)

# torch is only installed when importing it fails for any reason.
try:
    import torch
except Exception:
    pipi(['torch'])

print('Dependencies installed.')


Dependencies installed.


In [2]:
#@title Clone / detect repo and choose a branch (no nesting; lists ALL remote branches)
import os, sys, subprocess
from pathlib import Path

REPO_URL = "https://github.com/Alberta-Bionix-natHacks-2025/natHacks2025"  #@param {type:"string"}
REPO_URL = REPO_URL.strip()
# Folder name git will clone into: last URL segment without a trailing ".git".
# Stays None when no URL was provided.
if REPO_URL:
    REPO_NAME = REPO_URL.rstrip("/").rsplit("/", 1)[-1].removesuffix(".git")
else:
    REPO_NAME = None

def _run(cmd):
    return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

def _ensure_ipywidgets():
    try:
        import ipywidgets  # noqa
        return True
    except Exception:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "ipywidgets"])
            return True
        except Exception as e:
            print("Could not install ipywidgets:", e)
            return False

def _git_root():
    """Return the top-level directory of the enclosing git work tree, or None.

    None is returned whenever ``git rev-parse --show-toplevel`` exits non-zero.
    """
    result = _run(["git", "rev-parse", "--show-toplevel"])
    return Path(result.stdout.strip()) if result.returncode == 0 else None

def _safe_clone(url: str):
    """Make sure a checkout of ``url`` is available without nesting clones.

    Returns False for an empty URL or a failed ``git clone``; True otherwise.
    When already inside a git work tree nothing is cloned and the working
    directory is left unchanged. Otherwise an existing ``<name>/.git`` folder is
    reused, or a shallow clone is made, and the process changes into that folder.
    """
    if not url:
        return False
    # if we're already inside a repo, don't clone
    if _git_root():
        print("Already inside a git repo; skipping clone.")
        return True

    target = url.rstrip("/").rsplit("/", 1)[-1].removesuffix(".git")
    if not Path(target, ".git").exists():
        print("Cloning", url)
        clone = _run(["git", "clone", "--depth", "1", url])
        if clone.returncode != 0:
            print("Clone failed:\n", clone.stderr)
            return False
    else:
        # a sibling folder already holds a git repo, so use it as-is
        print(f"Using existing clone: {target}")
    os.chdir(target)
    return True

# -------- locate or clone repo without nesting --------
# Resolution order:
#   1. already inside a git work tree -> move to its top level
#   2. REPO_URL is set                -> reuse/clone via _safe_clone
#   3. neither                        -> stop with instructions
git_root = _git_root()
if git_root:
    # ensure we are at the repo top-level (not a nested subdir)
    os.chdir(git_root)
    print(f"Detected repo root at: {git_root}")
elif REPO_URL:
    ok = _safe_clone(REPO_URL)
    if not ok:
        raise SystemExit("Clone failed. Fix REPO_URL and re-run this cell.")
else:
    # REPO_NAME is derived from REPO_URL, so with an empty URL there is no
    # folder name to fall back to; the previous "enter existing folder" branch
    # could never run and has been removed.
    raise SystemExit("Place this notebook in your repo root OR set REPO_URL to clone it.")

# Expose the repo root to imports such as `src.*` used by later cells.
ROOT = Path(".").resolve()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
print("Repo path in sys.path ✅", ROOT)

# -------- branch dropdown: list ALL remote heads, then fetch+checkout --------
def _remote_branches():
    """List branch names on ``origin`` via ``git ls-remote --heads``.

    Returns ``(branches, error)``: a sorted list of names and ``None`` on
    success, or ``([], stderr_text)`` when the git command fails.
    """
    ls = _run(["git", "ls-remote", "--heads", "origin"])
    if ls.returncode != 0:
        return [], ls.stderr
    marker = "\trefs/heads/"
    names = {ln.split(marker, 1)[1] for ln in ls.stdout.splitlines() if marker in ln}
    return sorted(names), None


def _default_branch(names):
    """Prefer ``main``, then ``master``, then the first name; None for an empty list."""
    for preferred in ("main", "master"):
        if preferred in names:
            return preferred
    return names[0] if names else None


if Path(".git").exists():
    # make sure 'origin' exists (use REPO_URL if needed)
    if _run(["git", "remote", "get-url", "origin"]).returncode != 0 and REPO_URL:
        _run(["git", "remote", "add", "origin", REPO_URL])

    branches, ls_error = _remote_branches()
    if ls_error is not None:
        print("ls-remote failed:\n", ls_error)
    else:
        # None (not "") when there are no branches: a Dropdown rejects a value
        # that is not one of its options.
        default = _default_branch(branches)

        if _ensure_ipywidgets():
            import ipywidgets as widgets
            from IPython.display import display, clear_output

            dd = widgets.Dropdown(options=branches, value=default, description="Branch:")
            btn = widgets.Button(description="Checkout", button_style="primary")
            refbtn = widgets.Button(description="Refresh")
            out = widgets.Output()

            def refresh(_=None):
                """Re-query the remote heads and update the dropdown, keeping the selection when possible."""
                with out:
                    clear_output()
                    new, err = _remote_branches()
                    if err is not None:
                        print("ls-remote failed:\n", err); return
                    # Assigning `options` resets the widget to its first entry,
                    # so remember the current choice and restore it afterwards.
                    previous = dd.value
                    dd.options = new
                    if new:
                        dd.value = previous if previous in new else _default_branch(new)
                    print("Branches:", new)

            def checkout(_):
                """Shallow-fetch the selected branch from origin and check it out (resetting the local branch)."""
                with out:
                    clear_output()
                    br = dd.value
                    if not br:
                        print("No branch selected."); return
                    print(f"Fetching origin/{br} (shallow)…")
                    # Explicit refspec: the initial clone is shallow, so other
                    # branches have no remote-tracking ref until fetched here.
                    f = _run(["git", "fetch", "origin", f"+refs/heads/{br}:refs/remotes/origin/{br}", "--depth=1"])
                    if f.returncode != 0:
                        print("fetch failed:\n", f.stderr); return
                    print(f"Checking out {br} …")
                    co = _run(["git", "checkout", "-B", br, f"origin/{br}"])
                    if co.returncode != 0:
                        print("checkout failed:\n", co.stderr); return
                    print(co.stdout or "Checked out.")
                    print("HEAD:")
                    print(_run(["git", "--no-pager", "log", "--oneline", "-n", "1"]).stdout)

            btn.on_click(checkout)
            refbtn.on_click(refresh)
            display(widgets.HBox([dd, btn, refbtn]), out)
            print("Pick a branch and click Checkout.")
        else:
            print("ipywidgets not available; run: pip install ipywidgets")
else:
    print("Not a git repo; branch dropdown not available.")


Cloning https://github.com/Alberta-Bionix-natHacks-2025/natHacks2025
Repo path in sys.path ✅ /content/natHacks2025


HBox(children=(Dropdown(description='Branch:', index=6, options=('AccessingDataFromOpenBCI', 'FineTuningModel'…

Output()

Pick a branch and click Checkout.


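Cloned the wrong repo? **Want to start fresh?** Set `RUN_THIS = True` in the next cell to `cd` back to `/content`, then fix `REPO_URL` and re-run the clone cell above.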



In [17]:
#@title cd to /content  (OFF by default)
RUN_THIS = False  #@param {type:"boolean"}
# Opt-in reset of the working directory; does nothing unless RUN_THIS is ticked.
if not RUN_THIS:
    print("Skipped.")
else:
    import os
    if os.getcwd() != "/content":
        os.chdir("/content")
        print("cd /content")


Skipped.


In [18]:
#@title Route MOABB/MNE downloads into the repo
# Purpose: make MNE and MOABB store downloaded datasets under data/eeg/raw,
# resolved against the current working directory (expected to be the repo root
# after the clone cell has run).
import os
import mne, moabb
# Absolute path so the setting stays valid even if the working directory changes later.
RAW_DIR = os.path.abspath('data/eeg/raw')
os.makedirs(RAW_DIR, exist_ok=True)
# set_env=True is passed so the value is also applied to the process environment
# (per MNE's set_config documentation), not only to MNE's stored config.
mne.set_config('MNE_DATA', RAW_DIR, set_env=True)
moabb.set_download_dir(RAW_DIR)
# Read the value back from MNE to confirm what is actually in effect.
print('MNE_DATA =', mne.get_config('MNE_DATA'))
print('MOABB download dir set. ✅')


MNE_DATA = /content/natHacks2025/data/eeg/raw
MOABB download dir set. ✅


In [19]:
#@title Train with pipeline (saves weights + feature pipeline)
import subprocess, sys

# The trainer runs as a child process. Its stdout/stderr are piped and echoed
# line by line so training logs appear in the notebook while it runs (the
# previously recorded run showed no trainer output between the two messages).
# "-u" keeps the child's output unbuffered, matching the evaluator cell.
TRAIN_CMD = [sys.executable, '-u', '-m', 'testing.train_feature_with_pipeline_test_mlp']

print('Running trainer: testing/train_feature_with_pipeline_test_mlp.py')
with subprocess.Popen(
    TRAIN_CMD, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1
) as trainer:
    for line in trainer.stdout:
        print(line, end='')

# Same failure contract as check_call: a non-zero exit stops the notebook here.
if trainer.returncode != 0:
    raise subprocess.CalledProcessError(trainer.returncode, TRAIN_CMD)
print('\nTraining completed. Files in data/weights')


Running trainer: testing/train_feature_with_pipeline_test_mlp.py

Training completed. Files in data/weights


In [21]:
#@title Evaluate saved model
# Imported here so this cell also runs on its own (e.g. when training is
# skipped because weights already exist).
import subprocess, sys

print('Running evaluator: testing/eval_test_mlp.py')
res = subprocess.run(
    [sys.executable, "-u", "-m", "testing.eval_test_mlp"],  # -u = unbuffered stdout
    text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
# Always show the captured output first so a failure is diagnosable ...
print(res.stdout)
print("Exit code:", res.returncode)
# ... then stop "Run all" on a failed evaluation instead of continuing to the
# download/inference cells with stale or missing artifacts.
res.check_returncode()


Running evaluator: testing/eval_test_mlp.py
Loading BNCI 2014-001 via MOABB (Left/Right, 2s, 8–30 Hz)...
Choosing from all possible events
Trials after LR filter: 288, Channels: 22, Samples: 500
Class distribution: {0: 144, 1: 144}
Val accuracy: 0.9138
Confusion matrix:
 [[27  2]
 [ 3 26]]

Classification report:
               precision    recall  f1-score   support

        left       0.90      0.93      0.92        29
       right       0.93      0.90      0.91        29

    accuracy                           0.91        58
   macro avg       0.91      0.91      0.91        58
weighted avg       0.91      0.91      0.91        58


Exit code: 0


In [22]:
#@title (Optional) Download weights/pipeline
from pathlib import Path

# Artifacts written by the training cell, relative to the working directory.
weights = Path('data/weights/feature_mlp_pipeline.pth')
pipe = Path('data/weights/feature_pipeline.joblib')
print('Weights exist:', weights.exists(), weights)
print('Pipeline exists:', pipe.exists(), pipe)

# Best-effort browser download: any failure (including google.colab being
# unavailable) is reported with a message instead of raising.
try:
    from google.colab import files
    for artifact in (weights, pipe):
        if artifact.exists():
            files.download(str(artifact))
except Exception as e:
    print('Not in Colab download context or download blocked:', e)


Weights exist: True data/weights/feature_mlp_pipeline.pth
Pipeline exists: True data/weights/feature_pipeline.joblib


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [23]:
#@title Quick inference demo (single window)
import numpy as np
from src.realtime.inference import RealTimeClassifier
from src.pipeline.feature_pipeline import FeaturePipeline

# Artifact locations are named once so the pipeline and classifier cannot drift apart.
PIPELINE_PATH = 'data/weights/feature_pipeline.joblib'
WEIGHTS_PATH = 'data/weights/feature_mlp_pipeline.pth'

# Load pipeline to get feature_dim
fp = FeaturePipeline.load(PIPELINE_PATH)
input_dim = fp.feature_dim
print('Feature dim:', input_dim)

rtc = RealTimeClassifier(
    pipeline_path=PIPELINE_PATH,
    weights_path=WEIGHTS_PATH,
    input_dim=input_dim,
    n_classes=2,
    # The demo always ran on CPU (the old `'cuda' if False else 'cpu'` could
    # never select CUDA); this states it directly.
    device='cpu',
)

# Dummy window (22 chans x 500 samples), the shape reported by the evaluator log.
# Replace with real-time buffer later. A fixed seed makes the printed
# probabilities reproducible across "Restart & Run All".
rng = np.random.default_rng(0)
x = rng.standard_normal((22, 500))
probs = rtc.predict_proba(x)
print('probs:', probs)
print('pred:', rtc.predict(x))


Feature dim: 50
probs: [[0.01188426 0.9881158 ]]
pred: [1]


## Notes
- To **push weights back to GitHub**, you can use a Personal Access Token and `git push` from Colab, or download and commit locally.
- For **real-time**, feed `(C, T)` buffers from your Ganglion/BrainFlow stream into `RealTimeClassifier.predict(...)`.
