-
Notifications
You must be signed in to change notification settings - Fork 1
/
swd_utils.py
303 lines (269 loc) · 15.7 KB
/
swd_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# NOTE(review): the star imports presumably supply np, os, tf, tfa, seed and
# define_model used below — confirm against swd_model / read_TUSZ.
from swd_model import *
from read_TUSZ import *
import nitime.algorithms as tsa
import pandas as pd
from sklearn.model_selection import train_test_split
# Fix RNG seeds so the train/val splits and weight initialization are reproducible.
seed(162)
tf.random.set_seed(162)
def training(Xtrain, labels, test_set, test_labels, leave_n_out, train_num, target_path, in_size, out_size,
             conf_mat_write=True):
    """
    Conduct one individual training run of the two-channel model and report
    test-set performance.
    :param Xtrain: Input data that goes into the model during training, shape (N, 2, in_size)
    :param labels: Original input labels (one-hot)
    :param test_set: Input data that goes into the model during testing
    :param test_labels: Original test labels (one-hot)
    :param leave_n_out: the number of patients that is left to testing
    :param train_num: counts the fold number from 0 in the LOOCV configuration
    :param target_path: path where the model, checkpoints, and history will be saved, should end with /
    :param in_size: the length of the input trial on the time/frequency axis
    :param out_size: output dimension, e.g. 2 for one-hot with 2 classes
    :param conf_mat_write: whether to print the confusion matrix on the test set
    """
    # prepare the data: stratified 70/30 train/validation split (fixed random_state)
    x_train, x_val, y_train, y_val = train_test_split(Xtrain, labels, stratify=labels, test_size=0.3, random_state=1)
    Xtrain_II = x_train[:, 0, :]  # Channel-1: F7-T3 (train)
    Xtrain_V5 = x_train[:, 1, :]  # Channel-2: F8-T4 (train)
    xval_II = x_val[:, 0, :]  # Channel-1: F7-T3 (validation)
    xval_V5 = x_val[:, 1, :]  # Channel-2: F8-T4 (validation)
    # model definition (define_model comes from swd_model)
    model = define_model(in_size, out_size)
    # define callbacks
    # FIX: os.makedirs(..., exist_ok=True) also creates missing parent
    # directories and is race-free, unlike the original isdir-then-mkdir pair.
    os.makedirs(target_path, exist_ok=True)
    checkpoint_filepath = "leave_" + str(leave_n_out) + "_out_training_number_" + str(train_num)
    # keep only the model with the best validation F1 score
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=target_path + "checkpoint_" + checkpoint_filepath,
        save_weights_only=False,
        monitor='val_f1_score', mode='max',
        save_best_only=True)
    # stop after 100 epochs without val-F1 improvement, restoring the best weights
    stop_me = tf.keras.callbacks.EarlyStopping(monitor='val_f1_score', min_delta=0, patience=100, verbose=1, mode='max',
                                               baseline=None, restore_best_weights=True)
    # shrink the learning rate 10x after 75 stagnant epochs
    where_am_I = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_f1_score', factor=0.1, patience=75, verbose=1,
                                                      mode='max', min_delta=0.001, cooldown=0, min_lr=0)
    # model training
    history = model.fit(x=[Xtrain_II, Xtrain_V5], y=y_train, epochs=900, batch_size=80, verbose=1,
                        validation_data=([xval_II, xval_V5], y_val),
                        callbacks=[model_checkpoint_callback, stop_me, where_am_I])
    # save history; model.save overwrites the checkpoint with the restored best
    # weights (EarlyStopping has restore_best_weights=True), so the two agree
    pd.DataFrame.from_dict(history.history).to_csv(target_path + str(checkpoint_filepath) + ".csv", index=False)
    model.save(target_path + "checkpoint_" + checkpoint_filepath)
    print(
        "Complete: training num: " + str(train_num) + " for leave " + str(leave_n_out) + " out cross validation conf.")
    # print test results
    model.evaluate(x=[test_set[:, 0, :], test_set[:, 1, :]], y=test_labels, verbose=1)
    if conf_mat_write:
        predictions = model.predict(x=[test_set[:, 0, :], test_set[:, 1, :]])  # get the predictions
        predictions_decoded = (predictions[:, 0] < predictions[:, 1]) * 1  # decode one-hot predictions to a vector
        test_labels_decoded = np.where(test_labels == 1)[1]  # get the original test labels as a vector
        thE_matrix = tf.math.confusion_matrix(labels=test_labels_decoded, predictions=predictions_decoded)
        print('Confusion Matrix: ')
        print(thE_matrix)
def time_leave_n_out_cross_validation(leave_n_out, npz_dataset_location, target_path, sample_dim=(1, 2, 2500)):
    """
    Run leave-n-out cross validation with the raw time series as model input.
    :param leave_n_out: for LOOCV this value should be 1
    :param npz_dataset_location: the path/name of the windowed ready-to-train dataset
    :param target_path: target folder where the model, checkpoint and history will be saved
    :param sample_dim: shape of one trial; only sample_dim[2] (time-axis length)
        is used now, the parameter is retained for interface compatibility
    """
    dataset = np.load(npz_dataset_location, allow_pickle=True)['patients'].item()  # get the data
    patients = list(dataset.keys())  # patient names; dict order fixes the fold assignment
    num_of_training = len(patients) // leave_n_out  # number of folds for leave-n-out conf.
    for train in range(num_of_training):
        test_patients = patients[train * leave_n_out:(train + 1) * leave_n_out]
        train_patients = sorted(set(patients).difference(test_patients))
        # FIX: collect per-patient arrays and concatenate once.  The original
        # concatenated inside the loop onto a dummy zero row, which copies the
        # accumulated array on every iteration (accidentally quadratic) and
        # then needed the [1:] slice to drop the dummy again.
        train_dataset = np.concatenate([dataset[pats]['data'] for pats in train_patients])
        train_labels = np.concatenate([dataset[pats]['labels'] for pats in train_patients])
        test_dataset = np.concatenate([dataset[pats]['data'] for pats in test_patients])
        test_labels = np.concatenate([dataset[pats]['labels'] for pats in test_patients])
        # convert labels to one-hot
        one_hot_train_labels = np.zeros((train_labels.size, int(train_labels.max()) + 1))
        one_hot_train_labels[np.arange(train_labels.size), train_labels] = 1
        one_hot_test_labels = np.zeros((test_labels.size, int(test_labels.max()) + 1))
        one_hot_test_labels[np.arange(test_labels.size), test_labels] = 1
        training(train_dataset, one_hot_train_labels, test_dataset, one_hot_test_labels, leave_n_out, train,
                 target_path,
                 in_size=sample_dim[2], out_size=2, conf_mat_write=True)
def multitaperpsd_leave_n_out_cross_validation(leave_n_out, npz_dataset_location, target_path, sample_dim=(1, 2, 1251)):
    """
    Run leave-n-out cross validation with the log10 multitaper power spectral
    density as model input.
    :param leave_n_out: for LOOCV this value should be 1
    :param npz_dataset_location: the path/name of the windowed ready-to-train dataset
    :param target_path: target folder where the model, checkpoint and history will be saved
    :param sample_dim: shape of one trial; only sample_dim[2] (frequency-axis
        length) is used now, the parameter is retained for interface compatibility
    """
    dataset = np.load(npz_dataset_location, allow_pickle=True)['patients'].item()  # get the data
    patients = list(dataset.keys())  # patient names; dict order fixes the fold assignment
    num_of_training = len(patients) // leave_n_out  # number of folds for leave-n-out conf.
    for train in range(num_of_training):
        test_patients = patients[train * leave_n_out:(train + 1) * leave_n_out]
        train_patients = sorted(set(patients).difference(test_patients))
        # CONSISTENCY: delegate the multitaper PSD + epsilon floor + log10 to
        # the existing data2input helper instead of repeating its body here.
        # FIX: concatenate once per fold — the original concatenated onto a
        # dummy zero row inside the loop (quadratic copying) and sliced it off.
        train_dataset = np.concatenate([data2input(dataset[pats]['data']) for pats in train_patients])
        train_labels = np.concatenate([dataset[pats]['labels'] for pats in train_patients])
        test_dataset = np.concatenate([data2input(dataset[pats]['data']) for pats in test_patients])
        test_labels = np.concatenate([dataset[pats]['labels'] for pats in test_patients])
        # convert labels to one-hot
        one_hot_train_labels = np.zeros((train_labels.size, int(train_labels.max()) + 1))
        one_hot_train_labels[np.arange(train_labels.size), train_labels] = 1
        one_hot_test_labels = np.zeros((test_labels.size, int(test_labels.max()) + 1))
        one_hot_test_labels[np.arange(test_labels.size), test_labels] = 1
        training(train_dataset, one_hot_train_labels, test_dataset, one_hot_test_labels, leave_n_out, train,
                 target_path,
                 in_size=sample_dim[2], out_size=2, conf_mat_write=True)
def assume(original_labels, predictions):
    """
    Apply the post-processing assumptions used in most of the epilepsy projects.
    For each index where the prediction disagrees with the original label:
    the label transitions around it (0->1 or 1->0) -> trust the original label;
    the label is 1 on both sides (1 1 0 1 1) -> the isolated middle is set to 1.
    :param original_labels: Original labels of the test data (not one-hot)
    :param predictions: Predicted labels of the test data (not one-hot)
    :return: a new array with the assumptions applied; the input is not modified
    """
    original_labels = np.asarray(original_labels)
    predictions = np.asarray(predictions)
    # FIX: predictions[:] on an ndarray returns a *view*, so the original code
    # silently mutated the caller's predictions array; copy() keeps it intact.
    assumed = predictions.copy()
    differences = np.where(original_labels != predictions)[0]  # already sorted ascending
    last = original_labels.size - 1
    for difference in differences:
        # FIX: guard the boundaries — index 0 wrapped around to the last
        # element via [difference - 1], and the last index raised IndexError
        # on original_labels[difference + 1]; boundary mismatches are kept as-is.
        if difference == 0 or difference == last:
            continue
        if original_labels[difference - 1] != original_labels[difference + 1]:
            # transition from 0 to 1 or 1 to 0, thus assumptions are applicable
            assumed[difference] = original_labels[difference]
        elif original_labels[difference - 1] and original_labels[difference + 1]:
            # a 1 1 0 1 1 run where the middle 0 should be assumed to be 1
            assumed[difference] = 1
    return assumed
def apply_assumptions(models_folder, dataset_path, sample_dim=(1, 2, 1251), is_time=False):
    """
    Apply the assumptions to every LOOCV test set and the corresponding saved
    model's predictions, printing a confusion matrix per patient.
    :param models_folder: folder where the per-fold model checkpoints reside
    :param dataset_path: path of the training dataset (.npz archive)
    :param sample_dim: sample dimension of the MODEL input; retained for
        interface compatibility, no longer used internally
    :param is_time: whether the model input is raw time series (True) or
        multitaper PSD (False)
    """
    dataset = np.load(dataset_path, allow_pickle=True)['patients'].item()
    patients = list(dataset.keys())
    num_of_training = len(patients)  # LOOCV: one fold per patient
    for train_num in range(num_of_training):
        test_patient = patients[train_num]
        data = dataset[test_patient]['data']
        if not is_time:
            data = data2input(data)  # multitaper PSD + log10 transform
        # FIX: build the fold arrays directly instead of concatenating onto a
        # dummy zero row and slicing it off afterwards; the explicit casts
        # match the dtypes the dummy-concatenation produced (float data,
        # int labels).
        test_dataset = np.asarray(data, dtype=float)
        test_labels = np.asarray(dataset[test_patient]['labels'], dtype=int)
        # convert labels to one-hot
        one_hot_test_labels = np.zeros((test_labels.size, int(test_labels.max()) + 1))
        one_hot_test_labels[np.arange(test_labels.size), test_labels] = 1
        # load the fold's best checkpoint (F1Score is a custom metric class)
        model_path = models_folder + "/checkpoint_leave_" + str(1) + "_out_training_number_" + str(train_num) + "/"
        model = tf.keras.models.load_model(model_path, custom_objects={"F1Score": tfa.metrics.F1Score}, compile=True)
        print(model_path)
        # confusion matrix after applying the assumptions
        predictions = model.predict(x=[test_dataset[:, 0, :], test_dataset[:, 1, :]])
        predictions_decoded = (predictions[:, 0] < predictions[:, 1]) * 1
        test_labels_decoded = np.where(one_hot_test_labels == 1)[1]
        assumed = assume(original_labels=test_labels_decoded, predictions=predictions_decoded)
        thE_matrix2 = tf.math.confusion_matrix(labels=test_labels_decoded, predictions=assumed)
        print('Confusion Matrix for: ' + test_patient)
        print(thE_matrix2)
def data2input(data):
    """
    Transform raw EEG windows into the model input representation.
    Computes the multitaper power spectral density (Fs=250 Hz, NW=6), floors
    it at a small epsilon so the logarithm is defined everywhere, and returns
    the base-10 log of the PSD.
    :param data: raw input data to which the multitaper PSD transform has not yet been applied
    :return: the log10 multitaper PSD
    """
    eps = 1e-10  # floor keeps every PSD bin strictly positive before log10
    _, psd, _ = tsa.multi_taper_psd(data, adaptive=False, jackknife=False, Fs=250, NW=6)
    clipped = np.where(psd > eps, psd, eps)
    return np.log10(clipped)
def calculate_metrics(TN, FP, FN, TP):
    """
    Calculate and print the discussed metrics from confusion-matrix counts.
    :param TN: True Negatives
    :param FP: False Positives
    :param FN: False Negatives
    :param TP: True Positives
    :return: dict with 'accuracy', 'sensitivity', 'specifity' (percentages)
        and 'FD'; returning the values (the original returned None) lets
        callers aggregate across folds and is backward compatible
    """
    total = TN + FP + FN + TP
    accuracy = 100 * (TP + TN) / total
    sensitivity = 100 * TP / (FN + TP)
    specifity = 100 * TN / (FP + TN)
    # NOTE(review): FP * 360 / total is presumably a false-detection rate per
    # hour assuming 360 windows per hour (10 s windows) — TODO confirm
    FD = FP * 360 / total
    print('accuracy: ' + str(accuracy))
    print('sensitivity: ' + str(sensitivity))
    print('specifity: ' + str(specifity))
    print('FD: ' + str(FD))
    return {'accuracy': accuracy, 'sensitivity': sensitivity, 'specifity': specifity, 'FD': FD}
def adjust_absz_patients(absz_dir, target_name, config, Fs=250):
    """
    Prepare a windowed, model-ready dataset from individual absence-seizure
    (absz) patient EEG .npz files and save everything as one .npz archive.
    :param absz_dir: directory where the per-patient absz .npz files exist, ex: "../absz/"
    :param target_name: base name of the output archive (".npz" is appended), ex: "absz_patients"
    :param config: dict with 'window_width' (seconds) and 'overlap' (fraction of the window)
    :param Fs: sampling frequency in Hz (default 250)
    :return: None; writes <target_name>.npz inside absz_dir, then restores the cwd
    """
    current_dir = os.getcwd()
    os.chdir(absz_dir)
    # the following three can be defined inside the loop to generalize for all sampling frequency
    sampling_frequency = Fs
    windowing_time = config['window_width']
    sample_per_label = sampling_frequency * windowing_time  # samples per window
    overlapping = config['overlap']
    shift = windowing_time * overlapping  # overlap expressed in seconds
    all_patient_data = {}
    # sorted() makes the patient iteration order deterministic across runs
    for patients in sorted(os.listdir()):
        patient_data = np.load(patients, allow_pickle=True)['patient'].item()
        data = []
        labels = []
        for sessions in patient_data.keys():
            for records in patient_data[sessions].keys():
                record = patient_data[sessions][records]
                num_of_labels = len(record['label'])
                # Bipolar derivations; rows 10-13 presumably map to F7, F8,
                # T3, T4 in the recording montage — TODO confirm vs read_TUSZ
                F7_T3 = record['data'][10, :] - record['data'][12, :]
                F8_T4 = record['data'][11, :] - record['data'][13, :]
                # Binarize: label value 5 presumably marks the absz class — confirm
                label = (record['label'][:] == 5) + 0
                for row in range(0, num_of_labels):
                    # window start advances by (window - overlap) seconds per label
                    eeg_signal_beginning = int(row * (windowing_time - shift) * sampling_frequency)
                    eeg_signal_end = int(eeg_signal_beginning + sample_per_label)
                    data.append(
                        [F7_T3[eeg_signal_beginning:eeg_signal_end], F8_T4[eeg_signal_beginning:eeg_signal_end]])
                    labels.append(label[row])
        all_patient_data[patients] = {}
        all_patient_data[patients]['data'] = np.array(data)
        all_patient_data[patients]['labels'] = np.array(labels)
    np.savez(target_name + '.npz', patients=all_patient_data)
    os.chdir(current_dir)