In [None]:
# Configuration and Imports
import os, json, pickle, random
import numpy as np
import jax
import jax.numpy as jnp
from typing import Dict, List, Tuple

# ============ CONFIGURATION ============
# Root folder that gather_files() walks looking for movement sub-directories.
BASE_DIR = 'path/to/your/data'  # Replace with actual data directory
# Maps a lower-cased directory name to its canonical movement label, so that
# spelling variants ('backwards', 'take_off') collapse onto a single class.
MOVEMENT_ALIASES = {
    'backwards': 'backward',
    'backward': 'backward',
    'forward': 'forward',
    'landing': 'landing',
    'left': 'left',
    'right': 'right',
    'takeoff': 'takeoff',
    'take_off': 'takeoff'
}
# Distinct canonical labels, i.e. the set of values of MOVEMENT_ALIASES.
KNOWN_MOVEMENTS = set(MOVEMENT_ALIASES.values())
# Seeds NumPy's legacy global RNG (last line of this cell) and is the default
# seed of stratified_split().
RNG_SEED = 42
# load_samples() keeps at most this many leading rows of every file.
MAX_ROWS_PER_FILE = 1500  # Set to None for all rows
# Fraction of each class's samples that goes to the training split.
TRAIN_RATIO = 0.8
# Lower bound applied to the per-class feature variances in GaussianNBJAX.fit().
VAR_SMOOTHING = 1e-9
np.random.seed(RNG_SEED)

## Data Loading

In [None]:
import pandas as pd

def is_movement_dir(path_part: str) -> bool:
    """Tell whether a directory name (compared case-insensitively) denotes a movement class.

    Args:
        path_part: A single path component, e.g. the basename of a directory.

    Returns:
        True if the lower-cased name is an alias whose canonical label is in
        KNOWN_MOVEMENTS, False otherwise.
    """
    canonical = MOVEMENT_ALIASES.get(path_part.lower())
    if not canonical:
        # Unknown alias (or an empty canonical label): not a movement directory.
        return False
    return canonical in KNOWN_MOVEMENTS

def canonical_movement(path_part: str) -> str:
    """Look up the canonical movement label for a directory name.

    Args:
        path_part: A single path component; matching ignores letter case.

    Returns:
        The canonical label from MOVEMENT_ALIASES, or None when the name is
        not a known alias.
    """
    lowered = path_part.lower()
    return MOVEMENT_ALIASES.get(lowered)

def gather_files(base_dir: str) -> Dict[str, List[str]]:
    """Collect the CSV files found under movement-named directories.

    Walks ``base_dir`` recursively. Every directory whose own name is a known
    movement alias contributes its ``*.csv`` files (extension matched
    case-insensitively) to that movement's canonical label.

    The traversal is made deterministic (sorted directories, sorted files,
    sorted labels) because ``os.walk`` otherwise reports entries in arbitrary
    filesystem order, which would change the row order of the loaded dataset
    and therefore the outcome of any seeded shuffle applied later.

    Args:
        base_dir: Root directory to scan. A missing directory yields an empty
            result rather than an error.

    Returns:
        Dict mapping canonical movement label -> sorted list of file paths.
        Labels for which no file was found are omitted.
    """
    movement_files: Dict[str, List[str]] = {m: [] for m in sorted(KNOWN_MOVEMENTS)}
    for root, dirs, files in os.walk(base_dir):
        # Sorting in place makes os.walk descend in a reproducible order.
        dirs.sort()
        # normpath drops a trailing separator so 'data/left/' still yields 'left'.
        tail = os.path.basename(os.path.normpath(root))
        if not is_movement_dir(tail):
            continue
        label = canonical_movement(tail)
        for f in sorted(files):
            if f.lower().endswith('.csv'):
                movement_files[label].append(os.path.join(root, f))
    # Several alias directories can feed the same label; keep the merged list ordered.
    return {k: sorted(v) for k, v in movement_files.items() if v}

# Index the CSV files under BASE_DIR by canonical movement label; labels
# without any file are absent from the mapping.
movement_to_files = gather_files(BASE_DIR)
print(f'Found movements: {list(movement_to_files.keys())}')

In [None]:
def _read_eeg_file(fp: str):
    for attempt in ['\t', ',', 'whitespace']:
        try:
            if attempt == 'whitespace':
                df = pd.read_csv(fp, sep='\s+', header=None, engine='python')
            else:
                df = pd.read_csv(fp, sep=attempt, header=None, engine='python')
            if df.shape[0] == 0:
                continue
            arr = df.select_dtypes(include=[float, int]).to_numpy(dtype=np.float32)
            if arr.size == 0:
                continue
            return arr
        except Exception:
            continue
    return None

def load_samples(movement_files: Dict[str, List[str]], max_rows_per_file: int = MAX_ROWS_PER_FILE):
    """Read every listed file and stack the rows into one labelled sample matrix.

    Each row of each file becomes one sample labelled with the file's movement.
    Files may disagree on column count, so the most common width among the
    readable files is taken as the feature dimension: wider files are cut to
    their first ``target_dim`` columns and narrower files are dropped.

    Class indices are assigned in sorted label order, and only to labels that
    still have at least one file after filtering, so the indices are contiguous
    and every class in ``index_to_label`` has samples.

    Args:
        movement_files: Mapping of movement label -> list of file paths.
        max_rows_per_file: Keep at most this many leading rows per file;
            ``None`` keeps all rows.

    Returns:
        Tuple ``(X, y, index_to_label)``: ``X`` is the stacked sample matrix,
        ``y`` the ``int32`` class index of each row, and ``index_to_label``
        maps class index -> movement label.

    Raises:
        ValueError: If no file could be read at all.
    """
    candidates: List[Tuple[str, np.ndarray]] = []
    skipped_empty = 0

    for label in sorted(movement_files.keys()):
        for fp in movement_files[label]:
            arr = _read_eeg_file(fp)
            if arr is None:
                skipped_empty += 1
                continue
            if max_rows_per_file is not None and arr.shape[0] > max_rows_per_file:
                arr = arr[:max_rows_per_file]
            candidates.append((label, arr))

    if not candidates:
        # Without this guard the failure surfaces later as an opaque
        # "argmax of an empty sequence" error.
        raise ValueError(
            f'No usable data files were loaded ({skipped_empty} unreadable or empty). '
            'Check that the data directory exists and contains movement '
            'sub-directories with CSV files.'
        )

    # Feature dimension = most frequent column count (ties -> the smaller width,
    # because np.unique returns sorted values and argmax picks the first maximum).
    widths, width_counts = np.unique([arr.shape[1] for _, arr in candidates], return_counts=True)
    target_dim = int(widths[np.argmax(width_counts)])

    kept = [(label, arr[:, :target_dim]) for label, arr in candidates if arr.shape[1] >= target_dim]
    skipped_fewer = len(candidates) - len(kept)

    # Index only labels that actually contribute samples, to avoid phantom classes.
    label_to_index: Dict[str, int] = {
        label: i for i, label in enumerate(sorted({label for label, _ in kept}))
    }
    index_to_label = {v: k for k, v in label_to_index.items()}

    X = np.concatenate([arr for _, arr in kept], axis=0)
    y = np.concatenate(
        [np.full((arr.shape[0],), label_to_index[label], dtype=np.int32) for label, arr in kept],
        axis=0,
    )

    print(f'Loaded: {X.shape[0]} samples, {target_dim} features, {len(index_to_label)} classes')
    if skipped_empty or skipped_fewer:
        print(f'Skipped files: {skipped_empty} unreadable/empty, '
              f'{skipped_fewer} with fewer than {target_dim} columns')
    return X, y, index_to_label

# Stack all files into one sample matrix with integer class labels, plus the
# class-index -> movement-label lookup used for reporting.
X_raw, y_raw, index_to_label = load_samples(movement_to_files)
print('Data shape:', X_raw.shape)

## Feature Engineering

In [None]:
def standardize(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Scale every column to zero mean and (approximately) unit standard deviation.

    Args:
        X: 2-D array with samples along axis 0.

    Returns:
        Tuple ``(X_scaled, mean, std)`` where ``mean`` is the per-column mean
        and ``std`` is the per-column standard deviation plus ``1e-8``; the
        offset keeps the division finite for constant columns.
    """
    center = X.mean(axis=0)
    scale = X.std(axis=0) + 1e-8
    scaled = (X - center) / scale
    return scaled, center, scale

# Variance-based filtering + derived statistics + nonlinear transforms
# NOTE(review): the variance mask and the standardization statistics below are
# computed over all of X_raw, i.e. before the train/test split made later on.

# 1) Keep only the columns whose variance exceeds the threshold.
var_threshold = 1e-3
raw_var = X_raw.var(axis=0)
mask = raw_var > var_threshold
X_var = X_raw[:, mask]

# 2) Element-wise nonlinear views of the retained columns.
X_abs = np.abs(X_var)
X_log1p = np.log1p(X_abs)

# 3) One summary value per sample (row) across the retained columns.
row_mean = X_var.mean(axis=1)
row_std = X_var.std(axis=1)
row_abs_mean = X_abs.mean(axis=1)
row_energy = np.square(X_var).mean(axis=1)

# 4) Final layout: [raw | abs | log1p | mean, std, abs-mean, energy].
X_feat = np.concatenate(
    [X_var, X_abs, X_log1p,
     np.column_stack([row_mean, row_std, row_abs_mean, row_energy])],
    axis=1,
)

X, feat_mean, feat_std = standardize(X_feat)
print(f'Engineered features: {X.shape}')

## Train/Test Split & Model Fitting

In [None]:
def stratified_split(X: np.ndarray, y: np.ndarray, train_ratio: float, seed: int = RNG_SEED):
    """Split samples into train and test sets while preserving class proportions.

    For each distinct label the member indices are shuffled with a generator
    seeded by ``seed``; the first ``int(n * train_ratio)`` of them go to the
    training set and the remainder to the test set.

    Args:
        X: Sample matrix, indexed along axis 0.
        y: Label of each sample.
        train_ratio: Fraction of every class assigned to the training set.
        seed: Seed for the NumPy random generator.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``; within each part the
        samples are grouped by class in ascending label order.
    """
    rng = np.random.default_rng(seed)
    train_chunks, test_chunks = [], []
    for cls in np.unique(y):
        members = np.nonzero(y == cls)[0]
        rng.shuffle(members)
        cut = int(len(members) * train_ratio)
        train_chunks.append(members[:cut])
        test_chunks.append(members[cut:])
    train_sel = np.concatenate(train_chunks)
    test_sel = np.concatenate(test_chunks)
    return X[train_sel], y[train_sel], X[test_sel], y[test_sel]

# Per-class split of the engineered features: TRAIN_RATIO of every class goes
# to training, the rest to testing.
X_train, y_train, X_test, y_test = stratified_split(X, y_raw, TRAIN_RATIO)
print(f'Train: {X_train.shape}, Test: {X_test.shape}')

In [None]:
class GaussianNBJAX:
    """Gaussian Naive Bayes classifier with fitting and scoring done in JAX.

    Each class is modelled by an independent normal distribution per feature.
    Labels are expected to be integers ``0 .. K-1``; the number of classes is
    taken as ``max(y) + 1``.

    Attributes set by ``fit`` (all ``float32`` NumPy arrays):
        theta_: Per-class feature means, shape ``(K, D)``.
        var_: Per-class feature variances, floored at ``var_smoothing``.
        class_prior_: Class prior probabilities, shape ``(K,)``.
    """

    def __init__(self, var_smoothing: float = VAR_SMOOTHING, uniform_priors: bool = True):
        """Store hyper-parameters.

        Args:
            var_smoothing: Minimum value allowed for any per-class variance.
            uniform_priors: If True every class gets prior ``1/K``; otherwise
                priors are the empirical class frequencies.
        """
        self.var_smoothing = float(var_smoothing)
        self.uniform_priors = bool(uniform_priors)
        self.class_prior_ = None
        self.theta_ = None
        self.var_ = None

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Estimate per-class means, variances and priors.

        Args:
            X: Sample matrix of shape ``(N, D)``.
            y: Integer class labels of shape ``(N,)``, values ``>= 0``.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            ValueError: If the shapes are inconsistent, the data is empty, or a
                label is negative.
        """
        X_j = jnp.asarray(X, dtype=jnp.float32)
        y_j = jnp.asarray(y, dtype=jnp.int32)
        if X_j.ndim != 2:
            raise ValueError(f'X must be 2-D (samples, features); got shape {X_j.shape}')
        if y_j.ndim != 1 or y_j.shape[0] != X_j.shape[0]:
            raise ValueError(f'y must be 1-D with one label per row of X; got {y_j.shape} for X {X_j.shape}')
        if y_j.shape[0] == 0:
            raise ValueError('Cannot fit on an empty dataset')
        if int(jnp.min(y_j)) < 0:
            raise ValueError('Class labels must be non-negative integers')

        num_classes = int(jnp.max(y_j) + 1)
        counts = jnp.bincount(y_j, length=num_classes)
        # A class index with no samples gets count 1 so the divisions stay finite.
        counts_f = jnp.maximum(counts.astype(jnp.float32), 1.0)
        classes = jnp.arange(num_classes)

        def sum_for_class(c):
            member = (y_j == c)
            return jnp.sum(jnp.where(member[:, None], X_j, 0.0), axis=0)

        def centered_sq_sum_for_class(c, mu_c):
            member = (y_j == c)
            dev = jnp.where(member[:, None], X_j - mu_c[None, :], 0.0)
            return jnp.sum(dev * dev, axis=0)

        means = jax.vmap(sum_for_class)(classes) / counts_f[:, None]
        # Two-pass variance (mean of squared deviations). The one-pass form
        # E[x^2] - E[x]^2 cancels catastrophically in float32 when a feature's
        # mean is large compared with its spread.
        sq_dev = jax.vmap(centered_sq_sum_for_class)(classes, means)
        vars_ = jnp.maximum(sq_dev / counts_f[:, None], self.var_smoothing)

        if self.uniform_priors:
            priors = jnp.ones(num_classes, dtype=jnp.float32) / num_classes
        else:
            priors = counts.astype(jnp.float32) / float(X_j.shape[0])
        self.theta_ = np.asarray(means, dtype=np.float32)
        self.var_ = np.asarray(vars_, dtype=np.float32)
        self.class_prior_ = np.asarray(priors, dtype=np.float32)
        return self

    @staticmethod
    @jax.jit
    def _predict_log_proba_jit(X: jnp.ndarray, mu: jnp.ndarray, var: jnp.ndarray, log_prior: jnp.ndarray) -> jnp.ndarray:
        """Joint log-likelihood ``log p(x | c) + log p(c)`` for every sample/class pair.

        ``X`` is ``(N, D)``, ``mu`` and ``var`` are ``(K, D)``, ``log_prior`` is
        ``(K,)``; the result is ``(N, K)``. The intermediate ``diff`` tensor has
        shape ``(N, K, D)``.
        """
        const_term = -0.5 * jnp.sum(jnp.log(2.0 * jnp.pi * var), axis=1)
        diff = X[:, None, :] - mu[None, :, :]
        quad = -0.5 * jnp.sum((diff * diff) / (var[None, :, :]), axis=2)
        log_lik = quad + const_term[None, :]
        return log_lik + log_prior[None, :]

    def predict_log_proba(self, X: np.ndarray) -> np.ndarray:
        """Return the per-class joint log-likelihood of each sample, shape ``(N, K)``.

        The values are not normalised across classes (rows do not log-sum to
        zero); use ``predict_proba`` for normalised probabilities.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if self.theta_ is None or self.var_ is None or self.class_prior_ is None:
            raise RuntimeError('GaussianNBJAX instance is not fitted yet; call fit() first.')
        X_j = jnp.asarray(X, dtype=jnp.float32)
        mu = jnp.asarray(self.theta_, dtype=jnp.float32)
        var = jnp.asarray(self.var_, dtype=jnp.float32)
        # The small offset keeps log() finite for a class whose prior is exactly 0.
        log_prior = jnp.log(jnp.asarray(self.class_prior_, dtype=jnp.float32) + 1e-12)
        out = self._predict_log_proba_jit(X_j, mu, var, log_prior)
        return np.asarray(out)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the most likely class index (``int32``) for each sample."""
        return np.argmax(self.predict_log_proba(X), axis=1).astype(np.int32)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return class probabilities per sample, shape ``(N, K)``, rows summing to 1."""
        logp = self.predict_log_proba(X)
        # Subtract the row maximum before exponentiating to avoid underflow.
        m = np.max(logp, axis=1, keepdims=True)
        p = np.exp(logp - m)
        p /= np.sum(p, axis=1, keepdims=True)
        return p

# Fit the classifier on the training split only; uniform_priors=True gives
# every class the same prior regardless of how many samples it has.
model = GaussianNBJAX(var_smoothing=VAR_SMOOTHING, uniform_priors=True)
model.fit(X_train, y_train)
print('Model trained.')

## Evaluation

In [None]:
def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Count how often each true class was predicted as each class.

    Args:
        y_true: Integer ground-truth class indices.
        y_pred: Integer predicted class indices, same shape as ``y_true``.

    Returns:
        ``int32`` matrix of shape ``(C, C)`` with ``C = max label + 1``; entry
        ``[t, p]`` is the number of samples with true class ``t`` predicted as
        ``p``. Empty input yields a ``(0, 0)`` matrix.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(f'y_true and y_pred must have the same shape; got {y_true.shape} and {y_pred.shape}')
    if y_true.size == 0:
        return np.zeros((0, 0), dtype=np.int32)
    C = int(max(y_true.max(), y_pred.max()) + 1)
    mat = np.zeros((C, C), dtype=np.int32)
    # Unbuffered add: repeated (t, p) pairs each contribute, unlike mat[t, p] += 1
    # with array indices.
    np.add.at(mat, (y_true, y_pred), 1)
    return mat

def classification_report(y_true: np.ndarray, y_pred: np.ndarray, index_to_label: Dict[int,str]) -> Dict:
    """Compute per-class precision, recall, F1 and support, plus overall accuracy.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        index_to_label: Maps each class index in ``0 .. max label`` to its name.

    Returns:
        Dict keyed by class name with ``precision``, ``recall``, ``f1`` and
        ``support`` entries, and an ``'overall'`` key holding ``accuracy``.
        Denominators carry a ``1e-12`` offset so a class with no predictions or
        no samples scores 0 instead of dividing by zero.
    """
    eps = 1e-12
    n_classes = int(max(y_true.max(), y_pred.max()) + 1)
    report = {}
    for c in range(n_classes):
        is_true = (y_true == c)
        is_pred = (y_pred == c)
        tp = np.sum(is_true & is_pred)
        fp = np.sum(~is_true & is_pred)
        fn = np.sum(is_true & ~is_pred)
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        entry = {
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'support': int(np.sum(is_true)),
        }
        report[index_to_label[c]] = entry
    overall_accuracy = float(np.mean(y_true == y_pred))
    report['overall'] = {'accuracy': overall_accuracy}
    return report

# Score the held-out split and derive the confusion matrix and per-class report.
y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred, index_to_label)
accuracy = report['overall']['accuracy']

print(f'Accuracy: {accuracy:.4f}')
print('\nPer-class metrics:')
for cls, metrics in report.items():
    # 'overall' carries only the accuracy, which was printed above.
    if cls == 'overall':
        continue
    print(f"  {cls}: precision={metrics['precision']:.3f}, recall={metrics['recall']:.3f}, f1={metrics['f1']:.3f}")

## Save Model

In [None]:
# Save model to pickle
# The state holds everything needed to score new raw data: the fitted
# parameters AND the preprocessing (variance mask + standardization statistics)
# that produced the model's input features.
model_path = 'gaussiannb_jax_model.pkl'
model_metadata = {
    'accuracy': accuracy,
    'n_train': int(X_train.shape[0]),
    'n_test': int(X_test.shape[0]),
    'n_features': int(X_train.shape[1]),
    'classes': [index_to_label[i] for i in sorted(index_to_label.keys())],
    'var_smoothing': float(model.var_smoothing),
    # Read from the model rather than hard-coding, so it cannot drift from the
    # setting the model was actually built with.
    'uniform_priors': bool(model.uniform_priors)
}

model_state = {
    'theta_': model.theta_,
    'var_': model.var_,
    'class_prior_': model.class_prior_,
    'var_smoothing': model.var_smoothing,
    'uniform_priors': bool(model.uniform_priors),
    'metadata': model_metadata,
    'index_to_label': index_to_label,
    'label_to_index': {v: k for k, v in index_to_label.items()},
    # Preprocessing parameters from the feature-engineering step; without them
    # the saved model cannot be applied to raw input.
    'preprocessing': {
        'var_threshold': float(var_threshold),
        'feature_mask': np.asarray(mask),
        'feat_mean': np.asarray(feat_mean),
        'feat_std': np.asarray(feat_std)
    }
}

with open(model_path, 'wb') as f:
    pickle.dump(model_state, f, protocol=pickle.HIGHEST_PROTOCOL)

print(f'Model saved to: {model_path}')

## Load & Verify Model

In [None]:
def load_gaussiannb_jax_model(path: str):
    """Rebuild a GaussianNBJAX instance from a pickled state dictionary.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only load files that this notebook (or another trusted source) produced.

    Args:
        path: Location of the pickle file.

    Returns:
        Tuple ``(model, state)``: the restored classifier and the full state
        dict as stored in the file.
    """
    with open(path, 'rb') as f:
        state = pickle.load(f)
    restored = GaussianNBJAX(
        var_smoothing=state['var_smoothing'],
        uniform_priors=state['uniform_priors'],
    )
    # Copy the fitted parameters straight onto the new instance; no refit needed.
    for attr in ('theta_', 'var_', 'class_prior_'):
        setattr(restored, attr, state[attr])
    return restored, state

# Reload the model from disk and confirm the round trip preserved it.
model_loaded, state = load_gaussiannb_jax_model(model_path)
print(f'Model loaded from {model_path}')
print(f'Classes: {state["metadata"]["classes"]}')
print(f'Accuracy: {state["metadata"]["accuracy"]:.4f}')

# Actual verification: the reloaded model must reproduce the in-memory model's
# test-set predictions exactly.
assert np.array_equal(model_loaded.predict(X_test), y_pred), \
    'Reloaded model does not reproduce the original predictions'
print('Verification passed: reloaded model reproduces the test predictions.')