# WESAD 

In [1]:
import os
import torch
import numpy as np
from pathlib import Path

from scipy import signal
from scipy.stats import zscore
from sklearn.model_selection import train_test_split

In [2]:
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt


# -------------------- DEFINE CUSTOM TEMPLATE -------------------- #
# 'draft' template: margin of 50 on every side and a horizontal legend
# anchored by its bottom-right corner slightly above the plotting area.
_draft_legend = dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
_draft_margin = dict(l=50, r=50, b=50, t=50)
pio.templates['draft'] = go.layout.Template(
    layout=dict(margin=_draft_margin, legend=_draft_legend)
)

# combine the stock "plotly" template with the custom 'draft' settings
pio.templates.default = "plotly+draft"

## Read the records

In [3]:
# 
# Root folder of the dataset; each subject is read from "S<id>/S<id>.pkl".
data_path = Path("../../datasets/WESAD")
subject_id = [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16, 17]

# list of (subject id, {signal name -> array, 'label' -> array}) tuples
wesad = []
for subject in subject_id:
    # Build the path once. (Re-using double quotes inside a double-quoted
    # f-string is a SyntaxError on Python < 3.12.)
    record_path = data_path / f"S{subject}/S{subject}.pkl"

    # announce the file before loading it, so a failing read is easy to locate
    print(f"reading {record_path} ...")

    # SECURITY: allow_pickle=True unpickles the file, which can execute
    # arbitrary code. Only load records obtained from a trusted source.
    # With encoding='bytes' the dictionary keys are bytes (b'signal', ...).
    subject_data = np.load(record_path, allow_pickle=True, encoding='bytes')

    # keep only the chest signals and decode their names to str
    chest_signals = subject_data[b'signal'][b'chest']
    decoded_dict = {name.decode(): values for name, values in chest_signals.items()}
    decoded_dict['label'] = subject_data[b'label']

    wesad.append((subject, decoded_dict))

print("number of records:", len(wesad))

reading ../../datasets/WESAD/S2/S2.pkl ...
reading ../../datasets/WESAD/S3/S3.pkl ...
reading ../../datasets/WESAD/S4/S4.pkl ...
reading ../../datasets/WESAD/S5/S5.pkl ...
reading ../../datasets/WESAD/S6/S6.pkl ...
reading ../../datasets/WESAD/S7/S7.pkl ...
reading ../../datasets/WESAD/S8/S8.pkl ...
reading ../../datasets/WESAD/S9/S9.pkl ...
reading ../../datasets/WESAD/S10/S10.pkl ...
reading ../../datasets/WESAD/S11/S11.pkl ...
reading ../../datasets/WESAD/S13/S13.pkl ...
reading ../../datasets/WESAD/S14/S14.pkl ...
reading ../../datasets/WESAD/S15/S15.pkl ...
reading ../../datasets/WESAD/S16/S16.pkl ...
reading ../../datasets/WESAD/S17/S17.pkl ...
number of records: 15


In [4]:
# show which entries (signal names + 'label') the first loaded record contains
wesad[0][1].keys()

dict_keys(['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp', 'label'])

#### Decimate the records from 700 Hz to 350 Hz

In [5]:
decimate = True  # False keeps the records at their original 700 Hz

if decimate:
    # Halve the sampling rate of every record.
    # CAUTION: the records in `wesad` are overwritten in place, so running
    # this cell a second time decimates the data again.
    temp_dir = Path("wesad-350hz")
    freq = 350  # sampling frequency after decimation

    for _, record in wesad:
        for name in record.keys():
            if name != 'label':
                # low-pass filter + downsample along the time axis
                record[name] = signal.decimate(record[name], q=2, axis=0)
            else:
                # labels are categorical: keep every second value, no filtering
                record[name] = record['label'][::2]

else:
    freq = 700  # original sampling frequency

#### Process the Windows

The following IDs are provided: 0 = not defined / transient, 1 = baseline, 2 = stress, 3 = amusement, 4 = meditation, 5/6/7 = should be ignored in this dataset 

In [6]:
from numpy.lib.stride_tricks import sliding_window_view

def split_sample_signal(sample: dict[str,np.ndarray], labels: list, window_size: int, stride: int=None):
    """
    Cut one subject record into sliding windows and keep the windows whose
    most frequent label is one of `labels`.

    Parameters
    ----------
    sample : dict[str, np.ndarray]
        Subject record. Must contain a 1-D 'label' array and the signals
        'ACC', 'ECG', 'EMG', 'EDA', 'Temp' and 'Resp', with time along axis 0.
    labels : list
        Labels of interest; windows with any other majority label are dropped.
    window_size : int
        Size of the windows in number of observations.
    stride : int, optional
        Number of observations between the start of consecutive windows.
        Default 20% of the 'window_size' (never less than 1).

    Returns
    -------
    dict[str, np.ndarray]
        Dictionary with the kept windows: 'label' holds one label per window
        and each signal key holds the matching windows, with the time
        dimension as the last axis.

    Raises
    ------
    ValueError
        If an explicit `stride` smaller than 1 is given.

    """
    if stride is None:
        # int(window_size*0.2) is 0 for window_size < 5, and 0 is not a valid
        # slice step, so clamp the default to at least one observation.
        stride = max(1, int(window_size*0.2))
    elif stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    # split the labels with a sliding window -> [window x time]
    windows_labels = sliding_window_view(sample['label'], window_size)[::stride]

    # ----------------
    # Get the label based on the mode
    # ----------------
    def get_label(window: np.ndarray):
        # np.unique returns sorted values, so ties resolve to the smallest label
        values, counts = np.unique(window, return_counts=True)
        return values[np.argmax(counts)]

    windows_label = np.apply_along_axis(get_label, axis=1, arr=windows_labels)

    # keep only the windows whose label belongs to the classes of interest
    keep = np.isin(windows_label, labels)
    filtered_label = windows_label[keep]
    print("num windows: ", len(filtered_label))

    # ----------------
    # split the signals with the same window/stride and apply the same filter
    # ----------------
    sample_windows = {'label': filtered_label}
    for key in ['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp']:
        # windowing along axis 0 appends the time dimension as the last axis
        windows = sliding_window_view(sample[key], window_size, axis=0)[::stride]
        sample_windows[key] = windows[keep]

    return sample_windows


In [9]:
duration = 2 # windows size in seconds

# Signals stacked (in this order) as the channels of every window.
# Other configurations that can be used instead:
#   ['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp']  -> all signals
#   ['ECG', 'EMG', 'EDA', 'Temp', 'Resp']         -> accelerometer not used
#   ['ECG']                                       -> only ECG
data_keys = ['EMG', 'EDA', 'Temp', 'Resp']

# one entry per subject in each of the three lists
wesad_data, wesad_labels, wesad_id = [], [], []
for subject, record in wesad:
    print(f"[subject {subject}] ", end="")

    subject_windows = split_sample_signal(record, labels=[1,2,3,4], window_size=duration*freq)

    # stack the selected signals along the channel axis -> [window x channels x time]
    stacked = np.concatenate([subject_windows[name] for name in data_keys], axis=1)
    window_labels = subject_windows['label']

    wesad_data.append(stacked)
    wesad_labels.append(window_labels)
    # tag every window with the subject it came from
    wesad_id.append(np.full_like(window_labels, subject))

[subject 2] num windows:  7222
[subject 3] num windows:  7337
[subject 4] num windows:  7425
[subject 5] num windows:  7528
[subject 6] num windows:  7473
[subject 7] num windows:  7471
[subject 8] num windows:  7513
[subject 9] num windows:  7474
[subject 10] num windows:  7683
[subject 11] num windows:  7547
[subject 13] num windows:  7552
[subject 14] num windows:  7553
[subject 15] num windows:  7567
[subject 16] num windows:  7532
[subject 17] num windows:  7517


In [10]:
# shape of the first subject's stacked windows: [window x channels x time]
wesad_data[0].shape

(7222, 4, 700)

#### Split between train and validation

[**Recommended** in the case of four domains] Select only a 20% subset of the windows, stratified by class

In [11]:
choice_subset = True

if choice_subset:
    # Keep a 20% class-stratified sample of every subject's windows.
    np.random.seed(123) # fixed seed for reproducibility (train_test_split draws from the global RNG)

    processed_data, processed_labels, processed_id = [], [], []
    list_indices = []  # selected window indices, one array per subject

    for subject_windows, subject_labels, subject_ids in zip(wesad_data, wesad_labels, wesad_id):

        # draw the window indices to keep, stratified by the labels
        all_indices = np.arange(len(subject_windows))
        kept, _ = train_test_split(all_indices, train_size=0.2, stratify=subject_labels)
        kept.sort()  # restore chronological order

        # select the subset
        processed_data.append(subject_windows[kept])
        processed_labels.append(subject_labels[kept])
        processed_id.append(subject_ids[kept])

        # record the selected indices
        list_indices.append(kept)

else:
    # use every window
    processed_data, processed_labels, processed_id = wesad_data, wesad_labels, wesad_id

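Split between train and test: the windows of all subjects except the last one form the train/validation set, and the last subject is held out as the test set.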

In [12]:
def _zscore_windows(windows: np.ndarray) -> np.ndarray:
    """Z-normalize every window along the time axis (axis 2); constant windows become zeros."""
    normalized = zscore(windows, axis=2)
    # A window whose values are all identical has zero standard deviation, so
    # zscore returns NaN for it. Store zeros instead so that no NaN ends up in
    # the saved dataset.
    is_flat = windows.max(axis=2) == windows.min(axis=2)  # [window x channel]
    normalized[is_flat] = 0.0
    return normalized


# Train and validation set: every subject except the last one
train_wesad_data = np.concatenate(processed_data[:-1], axis=0)
train_wesad_labels = np.concatenate(processed_labels[:-1], axis=0)
train_wesad_id = np.concatenate(processed_id[:-1], axis=0)

# Test set: the last subject is held out
test_wesad_data = processed_data[-1]
test_wesad_labels = processed_labels[-1]
test_wesad_id = processed_id[-1]

# z-normalize the signals [mean=0 and std=1], independently per window and channel
train_wesad_data = _zscore_windows(train_wesad_data)
test_wesad_data = _zscore_windows(test_wesad_data)

print(train_wesad_data.shape, test_wesad_data.shape)

(20969, 4, 700) (1503, 4, 700)


#### Resample from 350 Hz to 360 Hz

In [13]:
# ----------------
# Resample the records from 350 Hz to 360 Hz
# ----------------
to_360hz = True

if to_360hz:
    # NOTE: this value is not used here; the save cell assigns `temp_dir`
    # again before anything is written.
    temp_dir = Path(f"wesad-360hz-{duration}second")

    resampled_size = int(duration * 360) # number of observations per window after resampling

    # signal.resample works along a given axis of the whole array, so there is
    # no need to call it once per window/channel through np.apply_along_axis.
    test_wesad_data = signal.resample(test_wesad_data, resampled_size, axis=2)
    train_wesad_data = signal.resample(train_wesad_data, resampled_size, axis=2)

    freq = 360 # frequency after resampling

    print(f"train: {train_wesad_data.shape} | test: {test_wesad_data.shape}")

train: (20969, 4, 720) | test: (1503, 4, 720)


In [14]:
univariate = True

n_channels = train_wesad_data.shape[1]
if univariate and n_channels > 1:
    # ----------------
    # Every channel becomes a sample of its own, so each label / subject id
    # is repeated once per channel (n_channels consecutive copies).
    # ----------------
    train_wesad_labels = np.repeat(train_wesad_labels, n_channels)
    train_wesad_id = np.repeat(train_wesad_id, n_channels)
    test_wesad_labels = np.repeat(test_wesad_labels, n_channels)
    test_wesad_id = np.repeat(test_wesad_id, n_channels)

    # ----------------
    # isolate each channel: [window x channels x time] -> [window*channels x 1 x time]
    # ----------------
    test_length = test_wesad_data.shape[2]
    test_wesad_data = test_wesad_data.reshape((-1, 1, test_length))

    train_length = train_wesad_data.shape[2]
    train_wesad_data = train_wesad_data.reshape((-1, 1, train_length))

train_wesad_data.shape, test_wesad_data.shape

((83876, 1, 720), (6012, 1, 720))

In [15]:
# quick visual check: first sample (row 0, channel 0) of the test set
px.line(y=test_wesad_data[0,0])

#### Save the dataset

In [16]:
# final shape of the training samples that will be saved
train_wesad_data.shape

(83876, 1, 720)

In [17]:
# number of training samples contributed by each subject id
np.unique(train_wesad_id, return_counts=True)

(array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16],
       dtype=int32),
 array([5776, 5868, 5940, 6020, 5976, 5976, 6008, 5976, 6144, 6036, 6040,
        6040, 6052, 6024]))

In [18]:
# window duration (seconds) and sampling frequency used to name the output folder
duration, freq

(2, 360)

In [19]:
# make sure the output folder exists (named after the sampling rate and window length)
temp_dir = Path(f"wesad-{freq}hz-{duration}second-no-ecg")
os.makedirs(temp_dir, exist_ok=True)

# file name -> (samples, labels, subject ids); written in this order
_splits = {
    "train.pt": (train_wesad_data, train_wesad_labels, train_wesad_id),
    "test.pt": (test_wesad_data, test_wesad_labels, test_wesad_id),
}
for _file_name, (_samples, _labels, _ids) in _splits.items():
    # torch.from_numpy wraps the arrays as tensors without copying them
    _payload = {
        'samples': torch.from_numpy(_samples),
        'labels': torch.from_numpy(_labels),
        'metadata': torch.from_numpy(_ids),
    }
    torch.save(_payload, f=temp_dir / _file_name)