In [7]:
import numpy as np
from scipy import stats
import pickle

# Define constants
DATA_DIR = 'WESAD/'
LABEL_SF = 700
SF_DICT = {'ACC': 32, 'BVP': 64, 'EDA': 4, 'TEMP': 4}
FEATURES = ['ACC', 'BVP', 'EDA', 'TEMP']
INVALID_LABELS = [0, 3, 5, 6, 7]

def load_pickle(file_path):
    with open(file_path, 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    return data

def create_labels(all_labels, sf):
    labels = [all_labels[x] for x in range(0, len(all_labels), sf // 2)]
    return np.array(labels)

def create_windows(data, sf, window_size=1):
    data_windows = [data[x : x + window_size * sf] for x in range(0, len(data), window_size * (sf // 2))]
    return data_windows

def normalize(data):
    mean = np.mean(np.mean(data, 0), 0)
    std = np.std(np.std(data, 0), 0)
    normalized_data = (data - mean) / std
    return normalized_data

def preprocess_data(data):
    preprocessed_data = {}
    for feature in FEATURES:
        preprocessed_data[feature] = create_windows(data[feature], SF_DICT[feature])
    return preprocessed_data

def filter_invalid_labels(labels):
    return [label for label in labels if label not in INVALID_LABELS]

def merge_labels(labels):
    merged_labels = []
    for label in labels:
        if label == 4:
            merged_labels.append(1)
        else:
            merged_labels.append(label - 1)
    return np.array(merged_labels)

def get_subject_data(subject):
    file_path = DATA_DIR + f'S{subject}/S{subject}.pkl'
    data = load_pickle(file_path)
    data_labels = data['label']
    window_labels = create_labels(data_labels, LABEL_SF)
    valid_labels = filter_invalid_labels(window_labels)
    preprocessed_data = preprocess_data(data['signal']['wrist'])
    preprocessed_data = {key: preprocessed_data[key][:len(valid_labels)] for key in preprocessed_data}
    merged_labels = merge_labels(valid_labels)
    return {'data': preprocessed_data, 'labels': merged_labels}

def combine_subject_data():
    combined_data = {'data': {feature: np.empty((0, SF_DICT[feature], 3 if feature == 'ACC' else 1)) for feature in FEATURES}, 'labels': []}
    for subject in range(2, 18):
        if subject != 12:
            subject_data = get_subject_data(str(subject))
            for feature in FEATURES:
                combined_data['data'][feature] = np.append(combined_data['data'][feature], subject_data['data'][feature], 0)
            combined_data['labels'] = np.append(combined_data['labels'], subject_data['labels'], 0)
    return combined_data

def save_data(data, file_path):
    with open(file_path, 'wb') as file:
        pickle.dump(data, file)

def save_formatted_data(folder_path):
    for subject in range(2, 18):
        if subject != 12:
            subject_data = get_subject_data(str(subject))
            normalized_data = {feature: normalize(subject_data['data'][feature]) for feature in FEATURES}
            data = {'data': normalized_data, 'labels': subject_data['labels']}
            save_data(data, folder_path + f'S{subject}.pkl')
    all_data = combine_subject_data()
    save_data(all_data, folder_path + 'All_ID.pkl')

def main():
    save_formatted_data('WESAD/')
    # You can add more functionality here

if __name__ == "__main__":
    main()

In [8]:
with open('WESAD/All_ID.pkl', 'rb') as file:
        data = pickle.load(file, encoding='latin1')

In [9]:
data

{'data': {'ACC': array([[[ 62., -21., 107.],
          [ 66.,  13.,  53.],
          [ 41.,   9.,  15.],
          ...,
          [ 40.,  17.,  44.],
          [ 49.,  19.,  40.],
          [ 46.,  21.,  26.]],
  
         [[ 53.,  21.,  -6.],
          [ 41.,  11.,  20.],
          [ 49.,   5.,  26.],
          ...,
          [ 56.,  16.,  36.],
          [ 54.,  17.,  26.],
          [ 47.,  16.,  29.]],
  
         [[ 48.,  24.,  15.],
          [ 50.,  24.,   7.],
          [ 53.,  26.,  -5.],
          ...,
          [ 56.,  22.,  14.],
          [ 62.,  22.,  11.],
          [ 57.,  22.,   6.]],
  
         ...,
  
         [[-59.,   3.,  27.],
          [-59.,   3.,  27.],
          [-59.,   3.,  26.],
          ...,
          [-59.,   3.,  26.],
          [-59.,   3.,  27.],
          [-59.,   3.,  26.]],
  
         [[-59.,   3.,  27.],
          [-59.,   4.,  27.],
          [-59.,   3.,  26.],
          ...,
          [-59.,   3.,  26.],
          [-59.,   3.,  26.],
       