In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
import pickle
import numpy as np
import os
import math
from tqdm.auto import tqdm

In [None]:
#define you path to dataset
BASE_PATH = '/content/drive/MyDrive/Wesad/dataset/WESAD'
#define the subject id
sub_id = 11

**Fetch Data from Chest and Wrist from Dataset**

In [None]:
with open(os.path.join(BASE_PATH, 'S{}'.format(sub_id), 'S{}.pkl'.format(sub_id)), 'rb') as pickle_file:
  content = pickle.load(pickle_file, encoding='latin1')

  # Converting the pickle file to dictionary
  dataset_dict = {}
  for key in content['signal']['chest']:
    dataset_dict['signal_chest_'+key] = content['signal']['chest'][key]

  for key in content['signal']['wrist']:
    dataset_dict['signal_wrist_'+key] = content['signal']['wrist'][key]

  dataset_dict['label'] = content['label']


In [None]:
#dictionary for different physiological signals as per sampling frequency
chest_keys = {
      'ACC': 700, 
      'ECG': 700, 
      'EMG': 700, 
      'EDA': 700, 
      'Resp': 700, 
      'Temp': 700, 
  }
wrist_keys = {
    'BVP': 64, 
    'TEMP': 4, 
    'EDA': 4, 
    'ACC': 32
  }

**Index Normalization**

In [None]:
def norm_index(hrtz,index,base_hrtz=700): 
  numerator = index//math.ceil(base_hrtz/hrtz) 
  return numerator

In [None]:
len(dataset_dict['label'])

3663100

**Data Preprocessing**

In [None]:
records = []
for idx in range(len(dataset_dict['label'])):
  a_record = {}
  for key in chest_keys:

    if 'ACC' in key:
      a_record['signal_chest_'+key] = np.mean(dataset_dict['signal_chest_'+key][idx])
    else:
      a_record['signal_chest_'+key] = dataset_dict['signal_chest_'+key][idx][0]
  
  for key in wrist_keys:
    if 'ACC' in key:
      a_record['signal_wrist_'+key] = np.mean(dataset_dict['signal_wrist_'+key][norm_index(wrist_keys[key], idx)])
    else:
      a_record['signal_wrist_'+key] = dataset_dict['signal_wrist_'+key][norm_index(wrist_keys[key], idx)][0]

  a_record['label'] = dataset_dict['label'][idx]
  records.append(a_record)

In [None]:
sn_dataframe = pd.DataFrame(records)

In [None]:
sn_dataframe.shape

(3663100, 11)

In [None]:
#sn_dataframe.columns = cols
sn_dataframe.head()

Unnamed: 0,signal_chest_ACC,signal_chest_ECG,signal_chest_EMG,signal_chest_EDA,signal_chest_Resp,signal_chest_Temp,signal_wrist_BVP,signal_wrist_TEMP,signal_wrist_EDA,signal_wrist_ACC,label
0,0.9854,0.020096,-0.01387,6.607437,2.120972,33.520264,-12.14,34.0,4.4455,42.0,0
1,0.9146,0.03685,-0.002014,6.603241,2.113342,33.494537,-12.14,34.0,4.4455,42.0,0
2,0.786333,0.05365,-0.007599,6.60553,2.122498,33.508148,-12.14,34.0,4.4455,42.0,0
3,0.644733,0.05278,-0.026505,6.611252,2.125549,33.491516,-12.14,34.0,4.4455,42.0,0
4,0.5034,0.035751,-0.040421,6.609726,2.131653,33.52478,-12.14,34.0,4.4455,42.0,0


In [None]:
#correlation matrix
corr = sn_dataframe.corr()
corr.style.background_gradient(cmap='coolwarm')

Unnamed: 0,signal_chest_ACC,signal_chest_ECG,signal_chest_EMG,signal_chest_EDA,signal_chest_Resp,signal_chest_Temp,signal_wrist_BVP,signal_wrist_TEMP,signal_wrist_EDA,signal_wrist_ACC,label
signal_chest_ACC,1.0,0.002554,-0.000358,0.249505,0.046068,-0.299046,-0.000842,0.250769,0.302463,-0.063182,-0.35416
signal_chest_ECG,0.002554,1.0,-0.16811,-0.001746,0.000549,0.000531,-0.000153,-3e-05,2.1e-05,-0.000241,-5.3e-05
signal_chest_EMG,-0.000358,-0.16811,1.0,-0.005086,-1e-05,-0.009253,0.000362,0.008254,-0.001845,-0.001786,-0.002709
signal_chest_EDA,0.249505,-0.001746,-0.005086,1.0,0.00175,-0.01748,5e-05,-0.328025,0.712809,-0.124507,-0.053212
signal_chest_Resp,0.046068,0.000549,-1e-05,0.00175,1.0,-0.000245,-0.002226,0.000166,-0.000585,-0.001955,-0.003223
signal_chest_Temp,-0.299046,0.000531,-0.009253,-0.01748,-0.000245,1.0,0.000128,-0.735092,-0.528273,0.086926,0.284965
signal_wrist_BVP,-0.000842,-0.000153,0.000362,5e-05,-0.002226,0.000128,1.0,-4.8e-05,0.000194,0.008064,0.00031
signal_wrist_TEMP,0.250769,-3e-05,0.008254,-0.328025,0.000166,-0.735092,-4.8e-05,1.0,-0.001064,-0.099227,-0.330767
signal_wrist_EDA,0.302463,2.1e-05,-0.001845,0.712809,-0.000585,-0.528273,0.000194,-0.001064,1.0,0.030413,-0.251953
signal_wrist_ACC,-0.063182,-0.000241,-0.001786,-0.124507,-0.001955,0.086926,0.008064,-0.099227,0.030413,1.0,-0.284505


**Split Train Test Data**

In [None]:
from sklearn.model_selection import train_test_split
X = sn_dataframe.drop(columns=['label'])
Y = sn_dataframe.label
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.35)

**Logistic Regression**

In [None]:
from sklearn.linear_model import LogisticRegression
model = LogisticRegression(multi_class='multinomial',max_iter=200).fit(X_train, Y_train)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


In [None]:
pred = model.predict(X_test)
from sklearn.metrics import accuracy_score
accuracy_score(Y_test, pred) * 100

79.92449798570298

In [None]:
from sklearn.metrics import f1_score
f1_score(Y_test, pred, average='micro')

0.7992449798570298

**K-Nearest Neighbours**

In [None]:
from sklearn.neighbors import KNeighborsClassifier
knn_classifier = KNeighborsClassifier().fit(X_train, Y_train)

In [None]:
knn_score = knn_classifier.predict(X_test)
from sklearn.metrics import accuracy_score
accuracy_score(Y_test, knn_score) * 100

99.79876529247281

In [None]:
from sklearn.metrics import f1_score
f1_score(Y_test, knn_score, average='micro')

0.9979876529247281

**Decision Tree**


In [None]:
from sklearn import tree
dt = tree.DecisionTreeClassifier().fit(X_train,Y_train)

In [None]:
dt_score = dt.predict(X_test)
from sklearn.metrics import accuracy_score
accuracy_score(Y_test, dt_score) * 100

99.92496597339489

In [None]:
dt.get_n_leaves()

2785

In [None]:
from sklearn.metrics import f1_score
f1_score(Y_test, knn_score, average='micro')

0.9979876529247281

**Save Model**

In [None]:
pickle.dump(model,open("/content/drive/MyDrive/Pattern Recognition/Results/logistic_{}.pkl".format(sub_id),"wb"))
pickle.dump(knn_classifier,open("/content/drive/MyDrive/Pattern Recognition/Results/knn_{}.pkl".format(sub_id),"wb"))
pickle.dump(dt,open("/content/drive/MyDrive/Pattern Recognition/Results/tree_{}.pkl".format(sub_id),"wb"))