In [7]:
import os
import pickle
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType

Matplotlib created a temporary cache directory at /scratch/msawires1/job_39290047/matplotlib-c6qhsdk0 because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


# Combining Pkl Tables - split up by modality 
## i.e. chest [ECG, EMG, EDA, Temp, Resp, ACC] and wrist [BVP, EDA, Temp, ACC] each have a table of their own

In [30]:
# Chest ECG

subject_ids = [f"S{i}" for i in range(2, 18) if i != 12]
base_path = "../ialtamirano/raw_data/WESAD"
chest_ecg_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    signal = data['signal']['chest']['ECG'].flatten()
    labels = data['label']
    N = min(len(signal), len(labels))

    chest_ecg_rows.extend([
        {"subject": subject, "label": int(labels[i]), "value": float(signal[i]), "sample": i}
        for i in range(N) if labels[i] in [1, 2, 3, 4]
    ])

df_chest_ecg = pd.DataFrame(chest_ecg_rows)
print("Chest ECG shape:", df_chest_ecg.shape)
df_chest_ecg.head()


Chest ECG shape: (31470603, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,0.030945,214583
1,S2,1,0.033646,214584
2,S2,1,0.033005,214585
3,S2,1,0.031815,214586
4,S2,1,0.03035,214587


In [32]:
#Showing that all subjects and labels are present

label_counts = df_chest_ecg.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label   count
0      S10      1  826000
1      S10      2  507500
2      S10      3  260400
3      S10      4  557200
4      S11      1  826000
5      S11      2  476000
6      S11      3  257600
7      S11      4  553701
8      S13      1  826001
9      S13      2  464800
10     S13      3  267400
11     S13      4  556499
12     S14      1  826000
13     S14      2  472500
14     S14      3  260401
15     S14      4  555800
16     S15      1  822500
17     S15      2  480200
18     S15      3  260400
19     S15      4  555799
20     S16      1  826000
21     S16      2  471101
22     S16      3  257600
23     S16      4  554399
24     S17      1  826700
25     S17      2  506100
26     S17      3  260400
27     S17      4  511700
28      S2      1  800800
29      S2      2  430500
30      S2      3  253400
31      S2      4  537599
32      S3      1  798000
33      S3      2  448000
34      S3      3  262500
35      S3      4  546001
36      S4      1  810601
37      S4  

In [33]:
# Chest EMG

chest_emg_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    signal = data['signal']['chest']['EMG'].flatten()
    labels = data['label']
    N = min(len(signal), len(labels))

    chest_emg_rows.extend([
        {"subject": subject, "label": int(labels[i]), "value": float(signal[i]), "sample": i}
        for i in range(N) if labels[i] in [1, 2, 3, 4]
    ])

df_chest_emg = pd.DataFrame(chest_emg_rows)
print("Chest EMG shape:", df_chest_emg.shape)
df_chest_emg.head()


Chest EMG shape: (31470603, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,-0.003708,214583
1,S2,1,-0.014145,214584
2,S2,1,0.010208,214585
3,S2,1,0.012634,214586
4,S2,1,0.00206,214587


In [34]:
#Showing that all subjects and labels are present

label_counts = df_chest_emg.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label   count
0      S10      1  826000
1      S10      2  507500
2      S10      3  260400
3      S10      4  557200
4      S11      1  826000
5      S11      2  476000
6      S11      3  257600
7      S11      4  553701
8      S13      1  826001
9      S13      2  464800
10     S13      3  267400
11     S13      4  556499
12     S14      1  826000
13     S14      2  472500
14     S14      3  260401
15     S14      4  555800
16     S15      1  822500
17     S15      2  480200
18     S15      3  260400
19     S15      4  555799
20     S16      1  826000
21     S16      2  471101
22     S16      3  257600
23     S16      4  554399
24     S17      1  826700
25     S17      2  506100
26     S17      3  260400
27     S17      4  511700
28      S2      1  800800
29      S2      2  430500
30      S2      3  253400
31      S2      4  537599
32      S3      1  798000
33      S3      2  448000
34      S3      3  262500
35      S3      4  546001
36      S4      1  810601
37      S4  

In [35]:
# Chest EDA

chest_eda_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    signal = data['signal']['chest']['EDA'].flatten()
    labels = data['label']
    N = min(len(signal), len(labels))

    chest_eda_rows.extend([
        {"subject": subject, "label": int(labels[i]), "value": float(signal[i]), "sample": i}
        for i in range(N) if labels[i] in [1, 2, 3, 4]
    ])

df_chest_eda = pd.DataFrame(chest_eda_rows)
print("Chest EDA shape:", df_chest_eda.shape)
df_chest_eda.head()


Chest EDA shape: (31470603, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,5.710983,214583
1,S2,1,5.719376,214584
2,S2,1,5.706406,214585
3,S2,1,5.712509,214586
4,S2,1,5.727005,214587


In [36]:
#Showing that all subjects and labels are present

label_counts = df_chest_eda.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label   count
0      S10      1  826000
1      S10      2  507500
2      S10      3  260400
3      S10      4  557200
4      S11      1  826000
5      S11      2  476000
6      S11      3  257600
7      S11      4  553701
8      S13      1  826001
9      S13      2  464800
10     S13      3  267400
11     S13      4  556499
12     S14      1  826000
13     S14      2  472500
14     S14      3  260401
15     S14      4  555800
16     S15      1  822500
17     S15      2  480200
18     S15      3  260400
19     S15      4  555799
20     S16      1  826000
21     S16      2  471101
22     S16      3  257600
23     S16      4  554399
24     S17      1  826700
25     S17      2  506100
26     S17      3  260400
27     S17      4  511700
28      S2      1  800800
29      S2      2  430500
30      S2      3  253400
31      S2      4  537599
32      S3      1  798000
33      S3      2  448000
34      S3      3  262500
35      S3      4  546001
36      S4      1  810601
37      S4  

In [37]:
# Chest Temp

chest_temp_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    signal = data['signal']['chest']['Temp'].flatten()
    labels = data['label']
    N = min(len(signal), len(labels))

    chest_temp_rows.extend([
        {"subject": subject, "label": int(labels[i]), "value": float(signal[i]), "sample": i}
        for i in range(N) if labels[i] in [1, 2, 3, 4]
    ])

df_chest_temp = pd.DataFrame(chest_temp_rows)
print("Chest Temp shape:", df_chest_temp.shape)
df_chest_temp.head()


Chest Temp shape: (31470603, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,29.083618,214583
1,S2,1,29.122437,214584
2,S2,1,29.115234,214585
3,S2,1,29.126709,214586
4,S2,1,29.100861,214587


In [38]:
#Showing that all subjects and labels are present

label_counts = df_chest_temp.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label   count
0      S10      1  826000
1      S10      2  507500
2      S10      3  260400
3      S10      4  557200
4      S11      1  826000
5      S11      2  476000
6      S11      3  257600
7      S11      4  553701
8      S13      1  826001
9      S13      2  464800
10     S13      3  267400
11     S13      4  556499
12     S14      1  826000
13     S14      2  472500
14     S14      3  260401
15     S14      4  555800
16     S15      1  822500
17     S15      2  480200
18     S15      3  260400
19     S15      4  555799
20     S16      1  826000
21     S16      2  471101
22     S16      3  257600
23     S16      4  554399
24     S17      1  826700
25     S17      2  506100
26     S17      3  260400
27     S17      4  511700
28      S2      1  800800
29      S2      2  430500
30      S2      3  253400
31      S2      4  537599
32      S3      1  798000
33      S3      2  448000
34      S3      3  262500
35      S3      4  546001
36      S4      1  810601
37      S4  

In [26]:
# Chest Resp

chest_resp_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    signal = data['signal']['chest']['Resp'].flatten()
    labels = data['label']
    N = min(len(signal), len(labels))

    chest_resp_rows.extend([
        {"subject": subject, "label": int(labels[i]), "value": float(signal[i]), "sample": i}
        for i in range(N) if labels[i] in [1, 2, 3, 4]
    ])

df_chest_resp = pd.DataFrame(chest_resp_rows)
print("Chest Resp shape:", df_chest_resp.shape)
df_chest_resp.head()


Chest Resp shape: (31470603, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,1.191711,214583
1,S2,1,1.139832,214584
2,S2,1,1.141357,214585
3,S2,1,1.15509,214586
4,S2,1,1.133728,214587


In [27]:
#Showing that all subjects and labels are present

label_counts = df_chest_resp.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label   count
0      S10      1  826000
1      S10      2  507500
2      S10      3  260400
3      S10      4  557200
4      S11      1  826000
5      S11      2  476000
6      S11      3  257600
7      S11      4  553701
8      S13      1  826001
9      S13      2  464800
10     S13      3  267400
11     S13      4  556499
12     S14      1  826000
13     S14      2  472500
14     S14      3  260401
15     S14      4  555800
16     S15      1  822500
17     S15      2  480200
18     S15      3  260400
19     S15      4  555799
20     S16      1  826000
21     S16      2  471101
22     S16      3  257600
23     S16      4  554399
24     S17      1  826700
25     S17      2  506100
26     S17      3  260400
27     S17      4  511700
28      S2      1  800800
29      S2      2  430500
30      S2      3  253400
31      S2      4  537599
32      S3      1  798000
33      S3      2  448000
34      S3      3  262500
35      S3      4  546001
36      S4      1  810601
37      S4  

In [28]:
# Chest ACC

chest_acc_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    acc = data['signal']['chest']['ACC']
    labels = data['label']
    N = min(len(acc), len(labels))

    chest_acc_rows.extend([
        {
            "subject": subject,
            "label": int(labels[i]),
            "ACC_x": float(acc[i][0]),
            "ACC_y": float(acc[i][1]),
            "ACC_z": float(acc[i][2]),
            "sample": i
        }
        for i in range(N) if labels[i] in [1, 2, 3, 4]
    ])

df_chest_acc = pd.DataFrame(chest_acc_rows)
print("Chest ACC shape:", df_chest_acc.shape)
df_chest_acc.head()


Chest ACC shape: (31470603, 6)


Unnamed: 0,subject,label,ACC_x,ACC_y,ACC_z,sample
0,S2,1,0.8914,-0.1102,-0.2576,214583
1,S2,1,0.8926,-0.1086,-0.2544,214584
2,S2,1,0.893,-0.1094,-0.258,214585
3,S2,1,0.8934,-0.1082,-0.2538,214586
4,S2,1,0.893,-0.1096,-0.257,214587


In [29]:
#Showing that all subjects and labels are present

label_counts = df_chest_acc.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label   count
0      S10      1  826000
1      S10      2  507500
2      S10      3  260400
3      S10      4  557200
4      S11      1  826000
5      S11      2  476000
6      S11      3  257600
7      S11      4  553701
8      S13      1  826001
9      S13      2  464800
10     S13      3  267400
11     S13      4  556499
12     S14      1  826000
13     S14      2  472500
14     S14      3  260401
15     S14      4  555800
16     S15      1  822500
17     S15      2  480200
18     S15      3  260400
19     S15      4  555799
20     S16      1  826000
21     S16      2  471101
22     S16      3  257600
23     S16      4  554399
24     S17      1  826700
25     S17      2  506100
26     S17      3  260400
27     S17      4  511700
28      S2      1  800800
29      S2      2  430500
30      S2      3  253400
31      S2      4  537599
32      S3      1  798000
33      S3      2  448000
34      S3      3  262500
35      S3      4  546001
36      S4      1  810601
37      S4  

## Now Wrist Data

In [30]:
# Wrist BVP

wrist_bvp_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    bvp = data['signal']['wrist']['BVP']
    labels = data['label']
    
    # Downsample labels to match BVP length
    factor = len(labels) // len(bvp)
    labels_ds = np.array([np.bincount(labels[i*factor:(i+1)*factor]).argmax() for i in range(len(bvp))])

    wrist_bvp_rows.extend([
        {"subject": subject, "label": int(labels_ds[i]), "value": float(bvp[i]), "sample": i}
        for i in range(len(bvp)) if labels_ds[i] in [1, 2, 3, 4]
    ])

df_wrist_bvp = pd.DataFrame(wrist_bvp_rows)
print("Wrist BVP shape:", df_wrist_bvp.shape)
df_wrist_bvp.head()


Wrist BVP shape: (2857002, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,28.52,21458
1,S2,1,-47.98,21459
2,S2,1,-113.26,21460
3,S2,1,-157.08,21461
4,S2,1,-183.7,21462


In [31]:
#Showing that all subjects and labels are present

label_counts = df_wrist_bvp.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label  count
0      S10      1  82600
1      S10      2  50750
2      S10      3  26040
3      S10      4  35214
4      S11      1  82600
5      S11      2  47600
6      S11      3  25760
7      S11      4  34148
8      S13      1  82600
9      S13      2  46479
10     S13      3  26739
11     S13      4  35873
12     S14      1  82600
13     S14      2  47250
14     S14      3  26040
15     S14      4  34598
16     S15      1  82250
17     S15      2  48020
18     S15      3  26040
19     S15      4  36092
20     S16      1  82600
21     S16      2  47110
22     S16      3  25760
23     S16      4  37028
24     S17      1  82670
25     S17      2  50610
26     S17      3  26040
27     S17      4  34324
28      S2      1  80080
29      S2      2  43050
30      S2      3  25340
31      S2      4  31668
32      S3      1  79800
33      S3      2  44800
34      S3      3  26250
35      S3      4  36442
36      S4      1  81060
37      S4      2  44450
38      S4      3  26040


In [32]:
# Wrist EDA

wrist_eda_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    eda = data['signal']['wrist']['EDA']
    labels = data['label']

    factor = len(labels) // len(eda)
    labels_ds = np.array([np.bincount(labels[i*factor:(i+1)*factor]).argmax() for i in range(len(eda))])

    wrist_eda_rows.extend([
        {"subject": subject, "label": int(labels_ds[i]), "value": float(eda[i]), "sample": i}
        for i in range(len(eda)) if labels_ds[i] in [1, 2, 3, 4]
    ])

df_wrist_eda = pd.DataFrame(wrist_eda_rows)
print("Wrist EDA shape:", df_wrist_eda.shape)
df_wrist_eda.head()


Wrist EDA shape: (179832, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,1.645664,1226
1,S2,1,1.640539,1227
2,S2,1,1.634132,1228
3,S2,1,1.614912,1229
4,S2,1,1.591848,1230


In [33]:
#Showing that all subjects and labels are present

label_counts = df_wrist_eda.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label  count
0      S10      1   4720
1      S10      2   2900
2      S10      3   1488
3      S10      4   3184
4      S11      1   4720
5      S11      2   2720
6      S11      3   1472
7      S11      4   3164
8      S13      1   4720
9      S13      2   2656
10     S13      3   1528
11     S13      4   3180
12     S14      1   4720
13     S14      2   2700
14     S14      3   1488
15     S14      4   3176
16     S15      1   4700
17     S15      2   2744
18     S15      3   1488
19     S15      4   3176
20     S16      1   4720
21     S16      2   2692
22     S16      3   1472
23     S16      4   3168
24     S17      1   4724
25     S17      2   2892
26     S17      3   1488
27     S17      4   2924
28      S2      1   4576
29      S2      2   2460
30      S2      3   1448
31      S2      4   3072
32      S3      1   4560
33      S3      2   2560
34      S3      3   1500
35      S3      4   3120
36      S4      1   4632
37      S4      2   2540
38      S4      3   1488


In [34]:
# Wrist Temp

wrist_temp_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    temp = data['signal']['wrist']['TEMP']
    labels = data['label']

    factor = len(labels) // len(temp)
    labels_ds = np.array([np.bincount(labels[i*factor:(i+1)*factor]).argmax() for i in range(len(temp))])

    wrist_temp_rows.extend([
        {"subject": subject, "label": int(labels_ds[i]), "value": float(temp[i]), "sample": i}
        for i in range(len(temp)) if labels_ds[i] in [1, 2, 3, 4]
    ])

df_wrist_temp = pd.DataFrame(wrist_temp_rows)
print("Wrist Temp shape:", df_wrist_temp.shape)
df_wrist_temp.head()


Wrist Temp shape: (179832, 4)


Unnamed: 0,subject,label,value,sample
0,S2,1,35.81,1226
1,S2,1,35.81,1227
2,S2,1,35.81,1228
3,S2,1,35.81,1229
4,S2,1,35.81,1230


In [35]:
#Showing that all subjects and labels are present

label_counts = df_wrist_temp.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label  count
0      S10      1   4720
1      S10      2   2900
2      S10      3   1488
3      S10      4   3184
4      S11      1   4720
5      S11      2   2720
6      S11      3   1472
7      S11      4   3164
8      S13      1   4720
9      S13      2   2656
10     S13      3   1528
11     S13      4   3180
12     S14      1   4720
13     S14      2   2700
14     S14      3   1488
15     S14      4   3176
16     S15      1   4700
17     S15      2   2744
18     S15      3   1488
19     S15      4   3176
20     S16      1   4720
21     S16      2   2692
22     S16      3   1472
23     S16      4   3168
24     S17      1   4724
25     S17      2   2892
26     S17      3   1488
27     S17      4   2924
28      S2      1   4576
29      S2      2   2460
30      S2      3   1448
31      S2      4   3072
32      S3      1   4560
33      S3      2   2560
34      S3      3   1500
35      S3      4   3120
36      S4      1   4632
37      S4      2   2540
38      S4      3   1488


In [36]:
# Wrist ACC

wrist_acc_rows = []

for subject in subject_ids:
    file_path = os.path.join(base_path, subject, f"{subject}.pkl")
    with open(file_path, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    acc = data['signal']['wrist']['ACC']
    labels = data['label']

    factor = len(labels) // len(acc)
    labels_ds = np.array([np.bincount(labels[i*factor:(i+1)*factor]).argmax() for i in range(len(acc))])

    wrist_acc_rows.extend([
        {
            "subject": subject,
            "label": int(labels_ds[i]),
            "ACC_x": float(acc[i][0]),
            "ACC_y": float(acc[i][1]),
            "ACC_z": float(acc[i][2]),
            "sample": i
        }
        for i in range(len(acc)) if labels_ds[i] in [1, 2, 3, 4]
    ])

df_wrist_acc = pd.DataFrame(wrist_acc_rows)
print("Wrist ACC shape:", df_wrist_acc.shape)
df_wrist_acc.head()


Wrist ACC shape: (1481709, 6)


Unnamed: 0,subject,label,ACC_x,ACC_y,ACC_z,sample
0,S2,1,42.0,-21.0,39.0,10218
1,S2,1,43.0,-22.0,39.0,10219
2,S2,1,43.0,-22.0,41.0,10220
3,S2,1,44.0,-21.0,39.0,10221
4,S2,1,44.0,-21.0,40.0,10222


In [37]:
#Showing that all subjects and labels are present

label_counts = df_wrist_acc.groupby(['subject', 'label']).size().reset_index(name='count')
print(label_counts)

   subject  label  count
0      S10      1  39334
1      S10      2  24166
2      S10      3  12400
3      S10      4  25143
4      S11      1  39333
5      S11      2  22667
6      S11      3  12267
7      S11      4  24236
8      S13      1  39334
9      S13      2  22133
10     S13      3  12733
11     S13      4  25521
12     S14      1  39334
13     S14      2  22500
14     S14      3  12400
15     S14      4  24929
16     S15      1  39167
17     S15      2  22866
18     S15      3  12400
19     S15      4  25189
20     S16      1  39333
21     S16      2  22433
22     S16      3  12267
23     S16      4  26213
24     S17      1  39367
25     S17      2  24100
26     S17      3  12400
27     S17      4  24367
28      S2      1  38134
29      S2      2  20500
30      S2      3  12067
31      S2      4  24343
32      S3      1  38000
33      S3      2  21333
34      S3      3  12500
35      S3      4  26000
36      S4      1  38600
37      S4      2  21167
38      S4      3  12400
