In [1]:
import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# 加载数据
data_pre = np.load('stft_feature_matrix_pre.npy')
data_inter = np.load('stft_feature_matrix_inter.npy')


In [4]:
data_pre = data_pre.reshape(22 * 641, 12607).T
data_inter = data_inter.reshape(22 * 641, 12607).T

#转为时间，频率*通道

In [5]:
# 创建标签
labels_pre = np.ones(data_pre.shape[0])   # 癫痫前期为1
labels_inter = np.zeros(data_inter.shape[0])  # 癫痫间期为0


In [6]:
# 合并数据标签
data = np.concatenate((data_pre, data_inter), axis=0)
labels = np.concatenate((labels_pre, labels_inter), axis=0)

In [8]:
# 划分训练集30%，验证集15%，测试集15%
X_train, X_temp, y_train, y_temp = train_test_split(data, labels, test_size=0.3, random_state=42)

X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)


In [3]:

channels = 22
frequency_points = 641

X_train = X_train.reshape(-1, channels, frequency_points, 1)
X_val = X_val.reshape(-1, channels, frequency_points, 1)
X_test = X_test.reshape(-1, channels, frequency_points, 1)



In [4]:
#保存划分数据
np.save('X_train.npy', X_train)
np.save('y_train.npy', y_train)
np.save('X_val.npy', X_val)
np.save('y_val.npy', y_val)
np.save('X_test.npy', X_test)
np.save('y_test.npy', y_test)

In [2]:
#读取划分数据
X_train = np.load('X_train.npy')
y_train = np.load('y_train.npy')
X_val = np.load('X_val.npy')
y_val = np.load('y_val.npy')
X_test = np.load('X_test.npy')
y_test = np.load('y_test.npy')

In [6]:
class CNNUnit(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(CNNUnit, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class CNNLayer(nn.Module):
    def __init__(self, in_channels, out_channels, num_units, stride=1):
        super(CNNLayer, self).__init__()
        units = []
        for _ in range(num_units):
            units.append(CNNUnit(in_channels, out_channels, stride))
            in_channels = out_channels  # 更新输入通道数
        self.units = nn.Sequential(*units)

    def forward(self, x):
        return self.units(x)

class TransformerBlock(nn.Module):
    def __init__(self, feature_size, num_heads, dropout_rate):
        super(TransformerBlock, self).__init__()
        self.transformer = nn.TransformerEncoderLayer(
            d_model=feature_size,
            nhead=num_heads,
            dropout=dropout_rate
        )

    def forward(self, x):
        return self.transformer(x)

class HybridCNNTransformer(nn.Module):
    def __init__(self):
        super(HybridCNNTransformer, self).__init__()
        # CNN Unit (assuming input channels is 1 for STFT)
        self.cnn_unit = CNNLayer(in_channels=channels, out_channels=16, num_units=3)
        
        # CNN and Transformer layers
        self.cnn1 = CNNLayer(in_channels=16, out_channels=24, num_units=1, stride=2)
        self.transformer1 = TransformerBlock(feature_size=24, num_heads=8, dropout_rate=0.1)
        
        self.cnn2 = CNNLayer(in_channels=24, out_channels=32, num_units=1, stride=2)
        self.transformer2 = TransformerBlock(feature_size=32, num_heads=8, dropout_rate=0.1)
        
        # Pooling layer
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Fully connected layers
        self.fc1 = nn.Linear(32, 512)
        self.fc2 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()

    def forward(self, x):
        # Initial CNN processing
        x = self.cnn_unit(x)
        
        # First CNN and Transformer stage
        x = self.cnn1(x)
        x = x.flatten(2)  # Flatten the CNN features for the transformer
        x = x.permute(2, 0, 1)  # Reshape for transformer (Seq, Batch, Feature)
        x = self.transformer1(x)
        x = x.permute(1, 2, 0).unsqueeze(-1)  # Reshape back (Batch, Feature, Seq, 1)
        
        # Second CNN and Transformer stage
        x = self.cnn2(x)
        x = x.flatten(2)  # Flatten the CNN features for the transformer
        x = x.permute(2, 0, 1)  # Reshape for transformer
        x = self.transformer2(x)
        x = x.permute(1, 2, 0).unsqueeze(-1)  # Reshape back
        
        # Pooling and fully connected layers
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Create the model
model = HybridCNNTransformer()
print(model)

HybridCNNTransformer(
  (cnn_unit): CNNLayer(
    (units): Sequential(
      (0): CNNUnit(
        (conv): Conv2d(22, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (1): CNNUnit(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (2): CNNUnit(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
  )
  (cnn1): CNNLayer(
    (units): Sequential(
      (0): CNNUnit(
        (conv): Conv2d(16, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, 

In [7]:
# 将数据转换为PyTorch张量
train_data = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).long())
val_data = TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long())

# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# 初始化模型、损失函数和优化器
model = HybridCNNTransformer()
criterion = nn.CrossEntropyLoss()  # 用于分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [11]:

# 定义训练函数
def train(model, criterion, optimizer, train_loader, val_loader, epochs):
    best_val_loss = float('inf')  # 初始化最佳验证损失
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        
        for inputs, labels in train_loader:
            # 清空优化器
            optimizer.zero_grad()
            
            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向传播和优化
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f'Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader)}')

        # 验证阶段
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        
        # 计算平均验证损失
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch+1}/{epochs} - Validation Loss: {avg_val_loss}')
        
        # 如果这个epoch的验证损失是迄今为止最佳的，保存模型
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model.pth')  # 保存模型参数
            print('Model saved')


In [14]:

# 训练模型
epochs = 4
print("start")
train(model, criterion, optimizer, train_loader, val_loader, epochs)

start


第一轮25分钟左右
第二轮36min

```
Epoch 1/2 - Loss: 0.5830323612992314
Epoch 1/2 - Validation Loss: 0.830383938550949
Epoch 2/2 - Loss: 0.5181694445402726
Epoch 2/2 - Validation Loss: 0.666413692633311
```
