In [31]:
import os,re
import numpy as np
import pandas as pd
from scipy import io
import pickle
import hdf5storage as hdf5

In [2]:
# Root folder of the SEED dataset. Written as a raw string so the backslash of
# the Windows path can never be read as an escape sequence. Set the SEED_ROOT
# environment variable to use another location without editing the notebook.
root_path = os.environ.get('SEED_ROOT', r'D:\SEED')
# Folder holding the extracted-feature .mat files that the cells below list and load.
feature_path = os.path.join(root_path,'ExtractedFeatures')
feature_path

'D:\\SEED\\ExtractedFeatures'

In [3]:
# --- dataset layout ---
num_subjects = 15
num_trials = 15
num_bands = 5
num_classes = 3

# --- signal / training settings ---
sf = 200 # sampling frequency
batch_size = 32

In [4]:
# List the feature folder and keep only the subject files: their names split
# into exactly two parts on "_", which also filters out label.mat.
data = [x for x in os.listdir(feature_path) if len(x.split("_")) == 2]
# order by the leading integer of the file name so each subject's files are adjacent
data.sort(key = lambda x : int(x.split("_")[0]))
# 3 files per subject, each file contains recordings for 15 trials
assert (len(data) == 3*num_subjects)

# load one sample
sample = hdf5.loadmat(os.path.join(feature_path, data[0]))

keys = list(sample.keys())
# 2 variants of each of the 6 feature types per trial (presumably LDS and one
# other smoothing - TODO confirm), plus 3 meta keys
assert (len(keys) == (2*6*num_trials+3))
print("One sample shape: (num_channels, num_windows, num_bands)")
print(sample["de_LDS1"].shape)

# get all features averaged with LDS; they are picked by name instead of by
# position so the result does not depend on how the keys are ordered in the file
features_LDS = [key for key in keys if "_LDS" in key]
print("Feature names LDS averaged")
print(features_LDS)
assert (len(features_LDS) == (num_trials*6))

labels = hdf5.loadmat(os.path.join(feature_path, "label.mat"))
# shift every label up by one and flatten to a 1-D vector with one entry per trial
labels = np.squeeze(labels["label"] + np.ones(num_trials, dtype=np.int8))
assert (labels.shape == (num_trials,))
print(type(labels[0]))
print(labels)

One sample shape: (num_channels, num_windows, num_bands)
(62, 235, 5)
Feature names LDS averaged
['de_LDS1', 'psd_LDS1', 'dasm_LDS1', 'rasm_LDS1', 'asm_LDS1', 'dcau_LDS1', 'de_LDS2', 'psd_LDS2', 'dasm_LDS2', 'rasm_LDS2', 'asm_LDS2', 'dcau_LDS2', 'de_LDS3', 'psd_LDS3', 'dasm_LDS3', 'rasm_LDS3', 'asm_LDS3', 'dcau_LDS3', 'de_LDS4', 'psd_LDS4', 'dasm_LDS4', 'rasm_LDS4', 'asm_LDS4', 'dcau_LDS4', 'de_LDS5', 'psd_LDS5', 'dasm_LDS5', 'rasm_LDS5', 'asm_LDS5', 'dcau_LDS5', 'de_LDS6', 'psd_LDS6', 'dasm_LDS6', 'rasm_LDS6', 'asm_LDS6', 'dcau_LDS6', 'de_LDS7', 'psd_LDS7', 'dasm_LDS7', 'rasm_LDS7', 'asm_LDS7', 'dcau_LDS7', 'de_LDS8', 'psd_LDS8', 'dasm_LDS8', 'rasm_LDS8', 'asm_LDS8', 'dcau_LDS8', 'de_LDS9', 'psd_LDS9', 'dasm_LDS9', 'rasm_LDS9', 'asm_LDS9', 'dcau_LDS9', 'de_LDS10', 'psd_LDS10', 'dasm_LDS10', 'rasm_LDS10', 'asm_LDS10', 'dcau_LDS10', 'de_LDS11', 'psd_LDS11', 'dasm_LDS11', 'rasm_LDS11', 'asm_LDS11', 'dcau_LDS11', 'de_LDS12', 'psd_LDS12', 'dasm_LDS12', 'rasm_LDS12', 'asm_LDS12', 'dcau_LDS1

In [5]:
# get the range of values across samples, over all LDS features and for the
# de_LDS feature alone; start from +/-inf so any real value replaces the start value
max_value = -np.inf
min_value = np.inf
de_max_value = -np.inf
de_min_value = np.inf

# get all the de_lds feature keys since for the final model I will only use de_lds features
de_keys = [key for key in features_LDS if "de_LDS" in key]
assert (len(de_keys) == num_trials)

# The loop uses its own names so the reference `sample` loaded earlier is not overwritten.
for file_name in data:
  session = hdf5.loadmat(os.path.join(feature_path, file_name))
  for key in features_LDS:
    # compute each extreme once and reuse it for both running ranges
    key_max = np.amax(session[key])
    key_min = np.amin(session[key])
    if key in de_keys:
      de_max_value = max(de_max_value, key_max)
      de_min_value = min(de_min_value, key_min)
    max_value = max(max_value, key_max)
    min_value = min(min_value, key_min)
print((min_value, max_value))
print(f'de range: {(de_min_value, de_max_value)}')

(-23.825413747036478, 1072646499245.6045)
de range: (10.567626836074302, 42.11366999020901)


#### Concatenate more windows in one training sample
*The blocks above treat each data sample as a single 1 s window. However, to classify the input effectively, several consecutive time steps (1 s windows) need to be taken into consideration together.*

In [18]:
def get_wind_from_file(file:str, w_len, drop_incomplete=True):
  """Cut the de_LDS feature of every trial in one file into samples of w_len windows.

  Relies on notebook-level state defined in earlier cells: feature_path,
  num_trials, num_bands, labels, de_min_value and de_max_value.

  Args:
    file: name of a .mat file inside feature_path.
    w_len: number of consecutive windows concatenated into one sample.
    drop_incomplete: if true (default), the trailing windows of a trial that do
      not fill a whole sample are discarded. If false, the trailing partial
      sample is kept and padded up to w_len by repeating its last window.

  Returns:
    (all_data, all_labels): all_data has shape
    (num_samples, num_channels, w_len*num_bands) and all_labels has shape
    (num_samples,); every sample carries the label of the trial it came from.

  Raises:
    ValueError: if no trial in the file yields a sample.
    AssertionError: if a sanity check on shapes or value range fails.
  """
  x = hdf5.loadmat(os.path.join(feature_path, file))
  all_data, all_labels = [], []

  for trial in range(num_trials): # 0-indexed
    # take only de_lds feature, looked up by its name
    f = x[f"de_LDS{trial+1}"] # (channels, windows, bands), e.g. 62*235*5
    num_channels, num_wind_trial, bands = f.shape
    assert (bands == num_bands)

    trial_data = []
    for j in range(0, num_wind_trial, w_len): # concat w_len windows
      window = f[:, j:j+w_len, :]
      missing = w_len - window.shape[1]
      if missing > 0:
        # only the last slice of a trial can come up short
        if drop_incomplete: break
        # repeat the final window: the sample gets its full length and its
        # values stay inside the range already observed for this feature
        window = np.pad(window, ((0, 0), (0, missing), (0, 0)), mode="edge")
      # per channel: the bands of the first window, then those of the second, ...
      trial_data.append(np.reshape(window, (num_channels, w_len*num_bands)))

    if not trial_data:
      # trial shorter than w_len with incomplete samples dropped: nothing to add
      continue

    trial_data = np.stack(trial_data, axis=0) # e.g. 47*62*25
    assert (trial_data.shape[1:] == (num_channels, w_len*num_bands))
    assert (np.amax(trial_data) <= de_max_value)
    assert (np.amin(trial_data) >= de_min_value)

    all_data.append(trial_data)
    # assign to each sample the corresponding trial label
    all_labels.append(np.full(trial_data.shape[0], labels[trial]))

  if not all_data:
    raise ValueError(f"no trial in {file} produced a sample for w_len={w_len}")

  all_data = np.concatenate(all_data, axis=0)
  all_labels = np.concatenate(all_labels, axis=0)
  assert (all_data.shape[1:] == (num_channels, num_bands*w_len))
  assert (all_labels.shape == (all_data.shape[0],))
  return all_data, all_labels

def get_wind_for_subject(file1, file2, file3, w_len):
  """Window the three files of one subject and join the results in file order.

  Each file goes through get_wind_from_file with the same w_len. Returns
  (series, l): the samples of all three files stacked along the first axis and
  the matching label vector. The sanity checks expect every file to contribute
  as many samples as the first one.
  """
  per_file = [get_wind_from_file(name, w_len) for name in (file1, file2, file3)]
  file_series = [windows for windows, _ in per_file] # e.g. 3 x (675*62*25)
  file_labels = [targets for _, targets in per_file] # e.g. 3 x (675,)

  series = np.concatenate(file_series, axis=0) # e.g. 2025*62*25
  l = np.concatenate(file_labels, axis=0) # e.g. 2025,

  expected_samples = 3*file_series[0].shape[0]
  assert (series.shape == (expected_samples, 62, w_len*num_bands))
  assert (l.shape == (expected_samples,))
  # values must stay inside the de range measured earlier in the notebook
  assert (np.amax(series) <= de_max_value)
  assert (np.amin(series) >= de_min_value)
  return series, l

# Build the windowed dataset of every subject, keep it per subject in two
# dictionaries keyed 'sub_1' ... 'sub_<num_subjects>', and pool everything into
# all_data / all_labels.
w_len = 5 # number of consecutive windows concatenated into one sample
per_subject_data, per_subject_labels = [], []
dic_all_data = {}
dic_all_labels = {}
for subject in range(num_subjects): # 0-indexed
  # data is sorted by subject, so each subject owns 3 consecutive files
  file1, file2, file3 = data[3*subject : 3*subject+3]
  subject_data, subject_labels = get_wind_for_subject(file1, file2, file3, w_len)
  dic_all_data[f'sub_{subject+1}'] = subject_data
  dic_all_labels[f'sub_{subject+1}'] = subject_labels
  per_subject_data.append(subject_data)
  per_subject_labels.append(subject_labels)

all_data = np.concatenate(per_subject_data, axis=0)
all_labels = np.concatenate(per_subject_labels, axis=0)
assert (all_data.shape[0] == all_labels.shape[0])

In [27]:
# shapes per subject first, then the shapes of the pooled arrays
for key, subject_windows in dic_all_data.items():
  subject_targets = dic_all_labels[key]
  print(key, subject_windows.shape, subject_targets.shape)
print('total:', all_data.shape, all_labels.shape)

sub_1 (2025, 62, 25) (2025,)
sub_2 (2025, 62, 25) (2025,)
sub_3 (2025, 62, 25) (2025,)
sub_4 (2025, 62, 25) (2025,)
sub_5 (2025, 62, 25) (2025,)
sub_6 (2025, 62, 25) (2025,)
sub_7 (2025, 62, 25) (2025,)
sub_8 (2025, 62, 25) (2025,)
sub_9 (2025, 62, 25) (2025,)
sub_10 (2025, 62, 25) (2025,)
sub_11 (2025, 62, 25) (2025,)
sub_12 (2025, 62, 25) (2025,)
sub_13 (2025, 62, 25) (2025,)
sub_14 (2025, 62, 25) (2025,)
sub_15 (2025, 62, 25) (2025,)
total: (30375, 62, 25) (30375,)


In [28]:
# display the subject keys stored in the per-subject data dictionary
dic_all_data.keys()

dict_keys(['sub_1', 'sub_2', 'sub_3', 'sub_4', 'sub_5', 'sub_6', 'sub_7', 'sub_8', 'sub_9', 'sub_10', 'sub_11', 'sub_12', 'sub_13', 'sub_14', 'sub_15'])

In [32]:
save_path = 'seed_processed'
# exist_ok makes the cell safe to re-run and avoids the check-then-create race
os.makedirs(save_path, exist_ok=True)

# pooled arrays of all subjects in NumPy format
np.save(os.path.join(save_path, "npy_all_subjects.npy"), all_data)
np.save(os.path.join(save_path, "npy_all_subjects_label.npy"), all_labels)

# save the per-subject dictionaries as .mat files
# NOTE: the 1-D label vectors come back as 2-D arrays of shape (1, N) when
# these files are read again with io.loadmat
io.savemat(os.path.join(save_path,'dic_all_subjects.mat'), dic_all_data)
io.savemat(os.path.join(save_path,'dic_all_subjects_label.mat'), dic_all_labels)

In [37]:
# Read both .mat files back as a round-trip check and print the stored shapes.
data_dic = io.loadmat(os.path.join(save_path, 'dic_all_subjects.mat'))
label_dic = io.loadmat(os.path.join(save_path, 'dic_all_subjects_label.mat'))
for key in data_dic:
    # keys ending in '__' are loadmat's metadata entries, not subjects
    if not key.endswith('__'):
        print(key, data_dic[key].shape, label_dic[key].shape)

sub_1 (2025, 62, 25) (1, 2025)
sub_2 (2025, 62, 25) (1, 2025)
sub_3 (2025, 62, 25) (1, 2025)
sub_4 (2025, 62, 25) (1, 2025)
sub_5 (2025, 62, 25) (1, 2025)
sub_6 (2025, 62, 25) (1, 2025)
sub_7 (2025, 62, 25) (1, 2025)
sub_8 (2025, 62, 25) (1, 2025)
sub_9 (2025, 62, 25) (1, 2025)
sub_10 (2025, 62, 25) (1, 2025)
sub_11 (2025, 62, 25) (1, 2025)
sub_12 (2025, 62, 25) (1, 2025)
sub_13 (2025, 62, 25) (1, 2025)
sub_14 (2025, 62, 25) (1, 2025)
sub_15 (2025, 62, 25) (1, 2025)


In [43]:
# Show the first row of one reloaded label array.
# NOTE(review): `key` is not set in this cell - it is whatever the last `for key in ...`
# loop left behind, so the subject shown depends on which cell ran before this one.
# Index with an explicit subject key to make this cell self-contained.
label_dic[key][0]

array([2, 2, 2, ..., 0, 0, 0], dtype=int16)