In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
import pickle
import numpy as np
import pandas as pd
from scipy.signal import resample
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import joblib

# Path to your WESAD folder
DATA_ROOT = "/content/drive/MyDrive/WESAD"
MODEL_PATH = "/content/drive/MyDrive/cognivox_wesad_lite.joblib" # Note the new filename

# Sampling rate of wrist data (We use ACC as reference which is 32Hz)
FS_WRIST = 32

# Window settings (10s window, 5s step)
WINDOW_SEC = 10
STEP_SEC = 5
WINDOW = WINDOW_SEC * FS_WRIST
STEP = STEP_SEC * FS_WRIST

subjects = sorted([d for d in os.listdir(DATA_ROOT) if d.startswith("S")])
print("Subjects:", subjects)

Mounted at /content/drive
Subjects: ['S10', 'S11', 'S13', 'S14', 'S15', 'S16', 'S17', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7', 'S8', 'S9']


In [None]:
def load_subject(subject_id):
    pkl_path = os.path.join(DATA_ROOT, subject_id, f"{subject_id}.pkl")
    with open(pkl_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")
    return data

def map_label_multiclass_to_binary(lbl):
    # 2 = Stress, [1, 3, 4] = Non-Stress, Others = Ignore
    if lbl == 2:
        return 1
    elif lbl in [1, 3, 4]:
        return 0
    else:
        return -1

def get_wrist_signals_and_labels(sub_data):
    wrist = sub_data["signal"]["wrist"]

    # 1. Load Accelerometer (Reference: 32Hz)
    acc = np.array(wrist["ACC"])

    # 2. Load BVP (Originally 64Hz) and Resample to 32Hz to match ACC
    bvp_original = np.array(wrist["BVP"]).flatten()
    bvp = resample(bvp_original, len(acc))

    # 3. Load Labels and Resample to match ACC
    labels_full = np.array(sub_data["label"])
    labels_resampled = resample(labels_full, len(acc))
    labels_resampled = np.rint(labels_resampled).astype(int)

    # Note: We completely ignore EDA and TEMP here
    return bvp, acc, labels_resampled

In [None]:
# def extract_window_features(bvp_w, acc_w):
#     """
#     bvp_w: 1D array (window)
#     acc_w: 2D array (window, 3)
#     returns: dict of features
#     """
#     feats = {}

#     # --- BVP (Heart Rate Proxy) ---
#     feats["bvp_mean"] = float(np.mean(bvp_w))
#     feats["bvp_std"] = float(np.std(bvp_w))

#     # --- ACC (Motion) ---
#     # Calculate magnitude: sqrt(x^2 + y^2 + z^2)
#     mag = np.linalg.norm(acc_w, axis=1)
#     feats["acc_mag_mean"] = float(np.mean(mag))
#     feats["acc_mag_std"] = float(np.std(mag))

#     return feats
def extract_window_features(bvp_w, acc_w):
    feats = {}

    # --- ENHANCED BVP (Heart Rate) FEATURES ---
    # simple statistics
    feats["bvp_mean"] = float(np.mean(bvp_w))
    feats["bvp_std"] = float(np.std(bvp_w))
    feats["bvp_min"] = float(np.min(bvp_w))
    feats["bvp_max"] = float(np.max(bvp_w))

    # Range (Max - Min) helps detect spikes
    feats["bvp_range"] = feats["bvp_max"] - feats["bvp_min"]

    # Energy (Sum of squares) can indicate intensity
    feats["bvp_energy"] = float(np.sum(bvp_w**2)) / len(bvp_w)

    # --- ACC (Motion) ---
    mag = np.linalg.norm(acc_w, axis=1)
    feats["acc_mean"] = float(np.mean(mag))
    feats["acc_std"] = float(np.std(mag))
    feats["acc_max"] = float(np.max(mag)) # Max movement intensity

    return feats

def build_subject_windows(subject_id):
    sub_data = load_subject(subject_id)
    # Get only BVP and ACC
    bvp, acc, labels = get_wrist_signals_and_labels(sub_data)

    n = len(acc)
    rows = []

    for start in range(0, n - WINDOW, STEP):
        end = start + WINDOW

        lab_window = labels[start:end]

        # Check label majority
        vals, counts = np.unique(lab_window, return_counts=True)
        if len(vals) == 0: continue
        maj_label = vals[np.argmax(counts)]
        bin_label = map_label_multiclass_to_binary(maj_label)

        if bin_label == -1:
            continue

        # Extract features (No EDA/Temp passed)
        feats = extract_window_features(
            bvp[start:end],
            acc[start:end]
        )
        feats["label"] = bin_label
        feats["subject"] = subject_id
        rows.append(feats)

    return rows

In [None]:
# 2. Re-build Dataset with New Features
print("Re-building dataset with enhanced features...")
all_rows = []
for sid in subjects:
    try:
        # Re-using the build logic but it will call the NEW extract_window_features above
        rows = build_subject_windows(sid)
        all_rows.extend(rows)
    except Exception as e:
        print(f"Error with {sid}: {e}")

df_lite = pd.DataFrame(all_rows)
print("New Dataset Shape:", df_lite.shape)

# 3. Train Model
feature_cols = [c for c in df_lite.columns if c not in ["label", "subject"]]
X = df_lite[feature_cols].values
y = df_lite["label"].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Increased estimators to 300 to capture the extra detail
clf_lite = RandomForestClassifier(
    n_estimators=300,
    random_state=42,
    class_weight="balanced",  # Critical for fixing the low Recall
    n_jobs=-1
)
clf_lite.fit(X_train, y_train)

# 4. Show New Results
print("\n--- NEW ACCURACY RESULTS (Lite Model) ---")
y_pred = clf_lite.predict(X_test)
print(classification_report(y_test, y_pred))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))

# 5. Save if results are good
MODEL_PATH = "/content/drive/MyDrive/cognivox_wesad_lite_enhanced.joblib"
joblib.dump(
    {
        "model": clf_lite,
        "feature_cols": feature_cols,
        "window_sec": WINDOW_SEC,
        "step_sec": STEP_SEC,
        "fs_wrist": FS_WRIST,
    },
    MODEL_PATH,
)
print(f"Enhanced model saved to: {MODEL_PATH}")

Re-building dataset with enhanced features...
New Dataset Shape: (8991, 11)

--- NEW ACCURACY RESULTS (Lite Model) ---
              precision    recall  f1-score   support

           0       0.91      0.96      0.93      1400
           1       0.83      0.65      0.73       399

    accuracy                           0.89      1799
   macro avg       0.87      0.80      0.83      1799
weighted avg       0.89      0.89      0.89      1799

Confusion Matrix:
 [[1348   52]
 [ 141  258]]
Enhanced model saved to: /content/drive/MyDrive/cognivox_wesad_lite_enhanced.joblib
