In [1]:
import sys, subprocess, pathlib, re

def sh(cmd):
    """Echo a shell command prefixed with ``>>`` and run it through the shell.

    Raises subprocess.CalledProcessError if the command exits non-zero.
    """
    print(f">> {cmd}")
    subprocess.run(cmd, shell=True, check=True)

sh(f"{sys.executable} -m pip -q install --upgrade pip setuptools wheel")
sh(f"{sys.executable} -m pip -q install synapseclient scikit-learn matplotlib tqdm")
sh(f"{sys.executable} -m pip -q install mne mne-connectivity xmltodict numpy scipy pandas joblib")
sh(f"{sys.executable} -m pip -q uninstall -y torcheeg || true")
sh("rm -rf torcheeg_src")
sh("git clone --depth 1 --branch v1.1.3 https://github.com/torcheeg/torcheeg.git torcheeg_src")

setup_py = pathlib.Path("torcheeg_src/setup.py")
txt = setup_py.read_text()

# Relax torcheeg's `scipy>=1.7.3,<=1.10.1` requirement to a lower bound only, presumably
# so installing it does not make pip replace the unpinned scipy installed above.
# re.subn reports how many occurrences were rewritten: if setup.py's layout differs
# from the pattern, warn instead of silently installing with the original pin.
txt2, n_relaxed = re.subn(r"scipy>=1\.7\.3\s*,\s*<=\s*1\.10\.1", "scipy>=1.7.3", txt)
if n_relaxed == 0:
    print("WARNING: scipy pin not found in torcheeg_src/setup.py; installing it unmodified.")
setup_py.write_text(txt2)

# now install with deps
sh(f"{sys.executable} -m pip -q install ./torcheeg_src")

>> /usr/bin/python3 -m pip -q install --upgrade pip setuptools wheel
>> /usr/bin/python3 -m pip -q install synapseclient scikit-learn matplotlib tqdm
>> /usr/bin/python3 -m pip -q install mne mne-connectivity xmltodict numpy scipy pandas joblib
>> /usr/bin/python3 -m pip -q uninstall -y torcheeg || true
>> rm -rf torcheeg_src
>> git clone --depth 1 --branch v1.1.3 https://github.com/torcheeg/torcheeg.git torcheeg_src
>> /usr/bin/python3 -m pip -q install ./torcheeg_src


In [2]:
import numpy as np
import scipy.io as sio
from scipy.signal import butter, filtfilt, lfilter

import torch
import torch.nn as nn

from torcheeg import transforms
from torcheeg.datasets.constants import FACED_CHANNEL_LOCATION_DICT


# Input recording and trained checkpoint (Colab-style absolute paths).
MAT_PATH = "/content/capture.mat"
CKPT_PATH = "/content/cnn_faced_best.pt"

# Sample-index slice [WINDOW_START:WINDOW_END] of the recording that is classified.
WINDOW_START = 0
WINDOW_END = 250

# Sampling rate (Hz) used when the .mat file's metadata does not provide one.
FALLBACK_FS = 250.0

# FACED electrode names for the recorded channels, in the order they appear in the
# .mat array (assumed to match the recorder's channel order — TODO confirm).
MAT_CHANNEL_NAMES = ["FP1", "FZ", "FP2"]

# Class names used when the checkpoint carries no meta["classes"].
EMO_FALLBACK = [
    "anger", "disgust", "fear",
    "sadness", "neutral", "amusement",
    "inspiration", "joy", "tenderness",
]

# (name, low Hz, high Hz); a None upper edge means "up to Nyquist".
BANDS = [
    ("delta", 0.5, 4.0),
    ("theta", 4.0, 8.0),
    ("alpha", 8.0, 13.0),
    ("beta", 13.0, 30.0),
    ("gamma", 30.0, None),
]

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


class SmallCNN(nn.Module):
    """Compact 2-D CNN classifier over (N, in_channels, H, W) inputs.

    Layout: two [conv3x3 -> BatchNorm -> ReLU -> maxpool2] stages (64, then 128
    filters), one more conv3x3(128) -> ReLU, global average pooling, and a linear
    head producing ``num_classes`` logits.

    All layers live in ``self.net`` (an nn.Sequential) at fixed indices; the
    state_dict keys (``net.0.weight``, ``net.12.bias``, ...) must keep matching saved
    checkpoints, so the layer order must not change.
    """

    @staticmethod
    def _conv_bn_pool(c_in, c_out):
        # One conv -> BN -> ReLU -> 2x2 max-pool stage, as a flat list of layers.
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]

    def __init__(self, in_channels, num_classes=9):
        super().__init__()
        layers = []
        layers += self._conv_bn_pool(in_channels, 64)
        layers += self._conv_bn_pool(64, 128)
        layers += [
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(128, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits


def _safe_filter(x, b, a):
    try:
        return filtfilt(b, a, x, axis=-1)
    except ValueError:
        return lfilter(b, a, x, axis=-1)


def butter_filter_1d(x, fs, low=None, high=None, order=4):
    """Butterworth-filter ``x`` along its last axis to the pass band [low, high] Hz.

    The filter type follows from which edges are given:
      * neither ``low`` nor ``high`` -> ``x`` is returned unchanged;
      * only ``high``                -> low-pass;
      * only ``low``                 -> high-pass;
      * both                         -> band-pass.
    Filtering goes through ``_safe_filter`` (zero-phase ``filtfilt`` with a causal
    ``lfilter`` fallback).

    Edge handling:
      * ``low <= 0`` is treated as "no lower edge" (``butter`` rejects Wn <= 0);
      * ``high`` at or above Nyquist is pulled to just below Nyquist;
      * an empty pass band (``high <= 0``, ``low >= Nyquist`` or ``low >= high``)
        yields ``np.zeros_like(x)``.

    Args:
        x: signal array, filtered along ``axis=-1``.
        fs: sampling rate in Hz.
        low, high: band edges in Hz, or None.
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        An array like ``x`` (``x`` itself when no edge applies).

    Raises:
        ValueError: if ``fs`` is not a positive, finite number.
    """
    nyq = fs / 2.0
    # `not (nyq > 0)` also rejects NaN, which `nyq <= 0` would let through.
    if not (nyq > 0) or not np.isfinite(nyq):
        raise ValueError(f"Bad fs={fs}")

    # A lower edge at or below DC imposes no constraint.
    if low is not None and low <= 0:
        low = None
    # Nothing survives a pass band that ends at or below 0 Hz.
    if high is not None and high <= 0:
        return np.zeros_like(x)

    if low is not None and low >= nyq:
        return np.zeros_like(x)
    if high is not None and high >= nyq:
        # Keep the normalized edge strictly below 1, as butter() requires.
        high = nyq - 1e-3
        if high <= 0:
            return np.zeros_like(x)

    if low is None and high is None:
        return x

    # NOTE(review): SciPy recommends the SOS form (output="sos" + sosfiltfilt) over
    # (b, a) for numerical robustness, especially for narrow, low-frequency bands.
    # Switching would change the extracted features, so (b, a) is kept here.
    if low is None:
        b, a = butter(order, high / nyq, btype="lowpass")
    elif high is None:
        b, a = butter(order, low / nyq, btype="highpass")
    else:
        if low >= high:
            return np.zeros_like(x)
        b, a = butter(order, [low / nyq, high / nyq], btype="bandpass")

    return _safe_filter(x, b, a)


def compute_de_BxC(eeg_CxT, fs, eps=1e-8):
    """Gaussian differential entropy of every (band, channel) pair.

    Each channel row of the 2-D array ``eeg_CxT`` is cast to float64 and filtered
    to every band in BANDS with ``butter_filter_1d`` (order 4). Its entropy is
    ``0.5 * log(2 * pi * e * (var + eps))``, where ``eps`` keeps the log finite
    for an all-zero (e.g. above-Nyquist) band.

    Returns:
        float32 array of shape (len(BANDS), n_channels).
    """
    n_channels, _n_samples = eeg_CxT.shape
    de_table = np.zeros((len(BANDS), n_channels), dtype=np.float32)
    two_pi_e = 2.0 * np.pi * np.e

    for ch in range(n_channels):
        signal = eeg_CxT[ch].astype(np.float64)
        for band_idx, (_band_name, f_lo, f_hi) in enumerate(BANDS):
            band_sig = butter_filter_1d(signal, fs, low=f_lo, high=f_hi, order=4)
            band_var = np.var(band_sig) + eps
            de_table[band_idx, ch] = np.float32(0.5 * np.log(two_pi_e * band_var))

    return de_table


def _fs_from_meta(mat):
    """Return the first valid (finite, > 0) sampling rate in ``mat['meta']``, else None.

    Fields are tried in order ``achieved_hz`` then ``target_hz``; a missing,
    malformed or non-positive field is skipped rather than ending the search.
    """
    if "meta" not in mat:
        return None
    try:
        meta = mat["meta"][0, 0]
        field_names = meta.dtype.names or ()
    except Exception:
        # Not a MATLAB struct in the expected (1, 1) layout: no usable metadata.
        return None

    for field in ("achieved_hz", "target_hz"):
        if field not in field_names:
            continue
        try:
            value = float(meta[field].item())
        except Exception:
            continue
        if np.isfinite(value) and value > 0:
            return value
    return None


def load_eeg_from_mat(mat_path, fallback_fs=250.0):
    """Load EEG as a (channels, samples) float64 array plus its sampling rate.

    The sampling rate comes from ``meta.achieved_hz`` or ``meta.target_hz`` when
    either holds a finite positive value, otherwise from ``fallback_fs``.

    The signal is taken from, in order:
      * ``data``: a 2-D array with >= 2 columns; column 0 is dropped (assumed to
        be a timestamp — TODO confirm) and the remaining (samples, channels)
        block is transposed;
      * the first of ``eeg``, ``EEG``, ``X``, ``signals``, ``data_eeg`` that
        squeezes to 2-D, oriented so that the shorter axis is channels.

    Raises:
        KeyError: if no suitable array is found.
    """
    mat = sio.loadmat(mat_path)

    fs = _fs_from_meta(mat)
    if fs is None:
        fs = float(fallback_fs)

    if "data" in mat and isinstance(mat["data"], np.ndarray) and mat["data"].ndim == 2:
        arr = mat["data"]
        if arr.shape[1] >= 2:
            eeg_TxC = arr[:, 1:]
            return eeg_TxC.T.astype(np.float64), fs

    for key in ["eeg", "EEG", "X", "signals", "data_eeg"]:
        if key in mat and isinstance(mat[key], np.ndarray):
            x = np.squeeze(mat[key])
            if x.ndim == 2:
                # Heuristic: recordings have fewer channels than samples.
                if x.shape[0] <= x.shape[1]:
                    return x.astype(np.float64), fs
                return x.T.astype(np.float64), fs

    raise KeyError(f"Couldn't find EEG array in {mat_path}. Keys: {list(mat.keys())}")


def faced_channel_order():
    """Return the FACED electrode names as a list, in FACED_CHANNEL_LOCATION_DICT key order."""
    return [electrode for electrode in FACED_CHANNEL_LOCATION_DICT]


def build_full_faced_feature(F_BxC_small, mat_channel_names):
    """Scatter features of the recorded channels into the full FACED channel layout.

    Column ``j`` of ``F_BxC_small`` (shape (len(BANDS), n_recorded)) is copied to
    the FACED column of ``mat_channel_names[j]``. FACED channels that were not
    recorded stay 0.

    Returns:
        float32 array of shape (len(BANDS), n_faced_channels).

    Raises:
        ValueError: if a name in ``mat_channel_names`` is not a FACED channel.
    """
    faced_names = faced_channel_order()
    column_of = dict(zip(faced_names, range(len(faced_names))))
    F_full = np.zeros((len(BANDS), len(faced_names)), dtype=np.float32)

    for src_col, name in enumerate(mat_channel_names):
        dst_col = column_of.get(name)
        if dst_col is None:
            raise ValueError(f"Channel name '{name}' not in FACED_CHANNEL_LOCATION_DICT keys.")
        F_full[:, dst_col] = F_BxC_small[:, src_col]

    return F_full


def togrid_like_training(F_full_BxC):
    """Place a (bands, electrodes) feature matrix on the FACED 2-D electrode grid.

    torcheeg's ``ToGrid`` (built from FACED_CHANNEL_LOCATION_DICT) receives the
    transposed (electrodes, bands) matrix. The name suggests this mirrors the
    training pipeline (TODO confirm). Returns the ``"eeg"`` entry when the
    transform yields a dict, otherwise the raw result.
    """
    grid_transform = transforms.ToGrid(FACED_CHANNEL_LOCATION_DICT)
    electrodes_by_bands = F_full_BxC.T
    result = grid_transform(eeg=electrodes_by_bands)
    if isinstance(result, dict):
        return result["eeg"]
    return result


def totensor_like_training(G_BxHxW):
    """Convert a gridded (bands, H, W) feature array with torcheeg's ``ToTensor``.

    Returns the ``"eeg"`` entry when the transform yields a dict, otherwise the
    raw result.
    """
    converter = transforms.ToTensor()
    converted = converter(eeg=G_BxHxW)
    if not isinstance(converted, dict):
        return converted
    return converted["eeg"]


def normalize_like_training(x):
    """Z-score each (sample, band) plane of a (N, C, H, W) tensor over its H x W grid.

    Uses torch's default (unbiased) std, floored at 1e-6, so a constant plane
    maps to all zeros instead of dividing by zero.
    """
    grid_dims = (2, 3)
    centre = x.mean(dim=grid_dims, keepdim=True)
    spread = torch.clamp(x.std(dim=grid_dims, keepdim=True), min=1e-6)
    return (x - centre) / spread


def load_model(ckpt_path):
    """Rebuild the SmallCNN stored at ``ckpt_path`` and return ``(model, classes)``.

    The checkpoint may be a bare state_dict or a dict containing ``"state_dict"``.
    Class names come from ``ckpt["meta"]["classes"]`` when present and not None,
    else EMO_FALLBACK. The model has 5 input planes (one per entry of BANDS) and
    one output per class, lives on DEVICE and is put in eval mode.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    ckpt_is_dict = isinstance(ckpt, dict)

    if ckpt_is_dict and "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        state = ckpt

    meta = ckpt.get("meta") if ckpt_is_dict else None
    if isinstance(meta, dict) and meta.get("classes") is not None:
        classes = meta["classes"]
    else:
        classes = EMO_FALLBACK

    model = SmallCNN(5, num_classes=len(classes)).to(DEVICE)
    model.load_state_dict(state)
    return model.eval(), classes


def main(topk=3):
    """Classify one EEG window from MAT_PATH and print the ``topk`` most likely classes.

    Steps:
      1. Load EEG and the sampling rate from MAT_PATH.
      2. Keep the first ``len(MAT_CHANNEL_NAMES)`` channels and slice samples
         [WINDOW_START:WINDOW_END].
      3. Compute per-band differential entropy and scatter it into the FACED layout.
      4. Grid, tensorize and normalize the features.
      5. Run the SmallCNN from CKPT_PATH and print the softmax probabilities in
         descending order.

    Args:
        topk: how many classes to print. Values <= 0 print none; values above the
            number of classes print all of them.

    Raises:
        ValueError: if the recording has fewer channels than MAT_CHANNEL_NAMES, or
            the window has fewer than 16 samples.
    """
    eeg_CxT, fs = load_eeg_from_mat(MAT_PATH, fallback_fs=FALLBACK_FS)

    n_used = len(MAT_CHANNEL_NAMES)
    if eeg_CxT.shape[0] < n_used:
        raise ValueError(f"Need at least {n_used} channels in mat, got {eeg_CxT.shape[0]}")
    # NOTE(review): assumes the first rows are in MAT_CHANNEL_NAMES order — confirm.
    eeg_used = eeg_CxT[:n_used, :]

    eeg_win = eeg_used[:, WINDOW_START:WINDOW_END]
    if eeg_win.shape[1] < 16:
        raise ValueError(f"Window too short: {eeg_win.shape}")

    # butter_filter_1d returns all zeros for a band whose lower edge is at/above
    # Nyquist, so that band's DE feature is a constant carrying no signal.
    nyq = fs / 2.0
    dead_bands = [name for name, lo, _hi in BANDS if lo is not None and lo >= nyq]
    if dead_bands:
        print(
            f"WARNING: fs={fs:.2f} Hz; bands above Nyquist ({nyq:.2f} Hz) "
            f"carry no signal: {', '.join(dead_bands)}"
        )

    F_small = compute_de_BxC(eeg_win, fs)
    F_full = build_full_faced_feature(F_small, MAT_CHANNEL_NAMES)
    G = togrid_like_training(F_full)

    t = totensor_like_training(G)
    x = t.float().unsqueeze(0).to(DEVICE)
    x = normalize_like_training(x)

    model, classes = load_model(CKPT_PATH)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    # Sort descending, then take the first `topk`. Unlike `argsort()[-topk:]`,
    # this yields nothing for topk == 0 instead of every class.
    top = np.argsort(probs)[::-1][:max(topk, 0)]
    print(f"fs={fs:.2f} Hz | window=[{WINDOW_START}:{WINDOW_END}] | x={tuple(x.shape)}")
    for i, k in enumerate(top, 1):
        print(f"{i}. {classes[k]}\t{probs[k]:.6f}")


if __name__ == "__main__":
    # Runs as a script and, in Jupyter (where __name__ is "__main__"), when the cell executes.
    main(9)

fs=48.80 Hz | window=[0:250] | x=(1, 5, 8, 9)
1. fear	0.999450
2. anger	0.000550
3. tenderness	0.000000
4. sadness	0.000000
5. inspiration	0.000000
6. disgust	0.000000
7. neutral	0.000000
8. joy	0.000000
9. amusement	0.000000
