In [2]:
# @NOTE: 原始数据读取在extractFeture.py中
#from data import extractFeture
import time
import numpy as np 
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score,f1_score
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from src.data.loaddata import load_data
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

### 🔵 数据预处理

In [5]:
# 记录数据预处理开始时间
data_preprocess_start_time = time.time()


subject_ids = [1,2,3]
# 打印当前目录
print(os.getcwd())
base_path = "src/data" # 数据存放路径
# 初始化 all_X 和 all_y 为空列表，用于存储所有主题的 X 和 y 数据
all_X = []
all_y = []

# 遍历所有主题
for subject_id in subject_ids:
    print(f"Loading data for subject {subject_id}...")
    # 加载数据
    X, y = load_data(subject_id, base_path)
    all_X.append(X)
    all_y.append(y)


# @TODO:23个通道的特征只是简单的拼接在一起，没有用到通道之间的关系
# 合并 all_X 和 all_y, vstack是垂直合并, concatenate是水平合并
# @DATA: X.shape = (27600,18), y.shape = (27600,) # (27600,18) -> (1200 , 23, 18) -> (1200 , 23 , 6 , 3)
X = np.vstack(all_X)
# @DATA: y.shape = (27600,)
y = np.concatenate(all_y)

# 记录数据预处理结束时间
data_preprocess_end_time = time.time()

# 计算数据预处理耗时
data_preprocess_time = data_preprocess_end_time - data_preprocess_start_time
print(f"Data preprocess time: {data_preprocess_time:.2f} seconds")


o:\eeg\epilepsy_EEG_analysis_code
Loading data for subject 1...


Loading EDF files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting EDF parameters from o:\eeg\epilepsy_EEG_analysis_code\src\data\chb01\chb01_03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 921599  =      0.000 ...  3599.996 secs...


  raw = mne.io.read_raw_edf(file_name, preload=True)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 845 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.5s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


Processing windows:  20%|█▉        | 178/900 [04:20<17:34,  1.46s/it]
Loading EDF files:   0%|          | 0/1 [04:21<?, ?it/s]


KeyboardInterrupt: 

In [3]:
# 记录过采样开始时间
oversampling_start_time = time.time()

# 初始化 SMOTE 实例
smote = SMOTE()

# @NOTE: SMOTE: Synthetic Minority Over-sampling Technique
# 1. 找到标签 y 中的少数类（例如 y=1）
# 2. 在少数类样本的特征空间中，通过现有样本之间的插值生成新的样本
# 3. 生成的新样本与少数类样本相似，从而增加少数类样本数量，平衡类别分布

# 应用 SMOTE 过采样
# @NOTE: 过采样平衡数据
# @NOTE: 过采样：对少数类样本进行插值，增加样本数量，使得少数类样本与多数类样本数量接近相等（不超原数据两倍）
# @NOTE: y = 1/0
X_resampled, y_resampled = smote.fit_resample(X, y)

# 记录过采样结束时间
oversampling_end_time = time.time()

# 计算过采样耗时
oversampling_time = oversampling_end_time - oversampling_start_time
print(f"Oversampling time: {oversampling_time:.2f} seconds")

# 分割处理后的数据集
# @TODO:了解数据内容与格式
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.3, random_state=0)
# @DATA: X_train.shape = (*,391), X_test.shape = (*,391), y_train.shape = (*,), y_test.shape = (*,)
# print(X_train.shape, X_test.shape, y_train.shape, y_test.shape) # (38189, 391) (16367, 391) (38189,) (16367,)


Oversampling time: 0.31 seconds


In [4]:

# 记录训练开始时间
start_time = time.time()

# 训练决策树分类器
# @IDEA:换成1DCNN
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)

# 记录训练结束时间
end_time = time.time()

# 计算训练耗时
train_time = end_time - start_time
print(f"Training time: {train_time:.2f} seconds")

# 对测试集进行预测
y_pred = clf.predict(X_test)

# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

# 计算 F1 分数
f1 = f1_score(y_test, y_pred)
print(f"F1 分数: {f1}")

Training time: 0.75 seconds
Accuracy: 0.9776379299810595
F1 分数: 0.9777318082258457
