In [3]:
!pip install numpy pandas scikit-learn xgboost matplotlib seaborn


Defaulting to user installation because normal site-packages is not writeable


In [None]:
# 1	Baseline: participant is sitting calmly before the task begins
# 2	Stress: participant is performing a stressful task (e.g., mental arithmetic under pressure)

In [7]:
import os
import pickle
import numpy as np
import pandas as pd
from scipy.signal import butter, filtfilt
import matplotlib.pyplot as plt
import seaborn as sns

# Dataset path
WESAD_PATH = r"D:\COAI_paper\WESAD\WESAD"

def butter_lowpass_filter(data, cutoff=3.0, fs=700, order=4):
    b, a = butter(order, cutoff / (0.5 * fs), btype='low')
    return filtfilt(b, a, data)

def extract_features_from_s2(data, start=200000, end=1800000, window_size=1500, step_size=750):
    ecg_raw = data['signal']['chest']['ECG']
    eda = data['signal']['chest']['EDA'][start:end]
    resp = data['signal']['chest']['Resp'][start:end]
    labels = data['label'][start:end]

    ecg = ecg_raw[start:end]
    if ecg.ndim == 2:
        ecg = ecg[:, 0]

    features, targets, subject_ids = [], [], []

    for i in range(0, len(labels) - window_size, step_size):
        seg_labels = labels[i:i + window_size]
        if not np.all(np.isin(seg_labels, [1, 2])):
            continue
        majority_label = np.bincount(seg_labels).argmax()

        ecg_seg = ecg[i:i + window_size]
        eda_seg = eda[i:i + window_size]
        resp_seg = resp[i:i + window_size]

        try:
            ecg_f = butter_lowpass_filter(ecg_seg)
            eda_f = butter_lowpass_filter(eda_seg)
            resp_f = butter_lowpass_filter(resp_seg)
        except Exception:
            continue

        feat = [
            np.mean(ecg_f), np.std(ecg_f),
            np.mean(eda_f), np.std(eda_f),
            np.mean(resp_f), np.std(resp_f)
        ]
        features.append(feat)
        targets.append(majority_label)
        subject_ids.append("S2")

    return features, targets, subject_ids

# Load S2
subject_path = os.path.join(WESAD_PATH, "S2", "S2.pkl")
with open(subject_path, 'rb') as f:
    data = pickle.load(f, encoding='latin1')

features, targets, sids = extract_features_from_s2(data)

print(f"‚úÖ Extracted segments: {len(features)}")
if len(features) == 0:
    raise ValueError("‚ùå Still no segments found. Something is wrong with data range or signal shape.")

# DataFrame
df = pd.DataFrame(features, columns=[
    'ECG_Mean', 'ECG_Std', 'EDA_Mean', 'EDA_Std', 'Resp_Mean', 'Resp_Std'
])
df['Label'] = targets
df['Subject'] = sids

print("‚úÖ Final dataset shape:", df.shape)
display(df.head())

# Plot
plt.figure(figsize=(8, 4))
sns.countplot(data=df, x='Label')
plt.title('S2 Stress vs. Baseline Segments')
plt.xlabel('Label (1 = Baseline, 2 = Stress)')
plt.ylabel('Count')
plt.tight_layout()
plt.show()


‚úÖ Extracted segments: 0


ValueError: ‚ùå Still no segments found. Something is wrong with data range or signal shape.

# Load & Preprocess WESAD Data

# Visualize Class Distribution

In [None]:
plt.figure(figsize=(12, 5))
sns.countplot(data=df, x='Label', hue='Subject')
plt.title('Distribution of Stress vs. Baseline Samples Across Subjects')
plt.xlabel('Label (1=Baseline, 2=Stress)')
plt.ylabel('Sample Count')
plt.legend(title='Subject', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


# Visualize Feature Distributions

In [None]:
plt.figure(figsize=(10, 6))
sns.boxplot(data=df, x='Label', y='ECG_Std')
plt.title('ECG Std by Stress Label')
plt.show()

plt.figure(figsize=(10, 6))
sns.boxplot(data=df, x='Label', y='EDA_Mean')
plt.title('EDA Mean by Stress Label')
plt.show()


In [35]:
# üîç Absolute label check for S2
with open(r"D:\COAI_paper\WESAD\WESAD\S2\S2.pkl", 'rb') as f:
    data = pickle.load(f, encoding='latin1')
    labels = data['label']
    unique, counts = np.unique(labels, return_counts=True)
    print("üîé Label frequencies in S2:")
    for u, c in zip(unique, counts):
        print(f"Label {u}: {c} samples")


üîé Label frequencies in S2:
Label 0: 2142701 samples
Label 1: 800800 samples
Label 2: 430500 samples
Label 3: 253400 samples
Label 4: 537599 samples
Label 6: 45500 samples
Label 7: 44800 samples
