In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
from tqdm import tqdm
from thop import profile, clever_format

# Ask cuDNN for deterministic kernels and switch its auto-tuner off.
# benchmark=True lets cuDNN choose algorithms by timing them, and that choice
# can differ between runs, which contradicts the deterministic setting above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def set_seed(seed=42):
    """Seed every random number generator this notebook can draw from.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU generator and all CUDA devices).

    Besides seeding, the cuDNN flags are set so that convolution-style kernels
    are selected deterministically.
    """
    # Local import: the file-level import block does not bring in ``random``.
    import random

    random.seed(seed)  # Python's built-in RNG
    np.random.seed(seed)  # NumPy global RNG
    torch.manual_seed(seed)  # PyTorch CPU RNG
    torch.cuda.manual_seed(seed)  # current GPU
    torch.cuda.manual_seed_all(seed)  # every GPU
    torch.backends.cudnn.deterministic = True  # deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False  # no timing-based algorithm search

set_seed(42)  # Call the function to apply the fixed random seed


# Experiment-wide settings read by the datasets, the model and the training loop.
config = dict(
    subjects_num=12,
    n_epochs=20,
    batch_size=50,
    save_name='logs/DualFusion-{epoch:02d}-{val_acc:.2f}',
    log_path1='logs/DualFusion_logs',  # modified
    num_class=2,  # modified; binary classification: 0 = awake, 1 = fatigued
)

isIntraSub = False  # modified


def get_device():
    """Return ``'cuda'`` when a CUDA device is available, otherwise ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'



class EEG_IntraSub_Dataset(Dataset):
    """Pooled-subject EEG dataset with a shuffled 80/10/10 train/val/test split.

    The recordings of all ``config['subjects_num']`` subjects are concatenated,
    shuffled with a permutation derived from ``split_seed`` and then sliced.
    Because the permutation depends only on ``split_seed`` and the number of
    samples, the 'train', 'val' and 'test' instances built from the same files
    partition the data without overlap. (Drawing the permutation from the
    global NumPy RNG gave each instance a different shuffle, so the three
    splits shared samples.)

    Args:
        path: Prefix prepended to ``sub_{i}_eeg.npy`` / ``sub_{i}_labels.npy``;
            it is concatenated as-is, so it must end with a path separator.
        mode: 'train' (first 80%), 'val' (next 10%) or 'test' (last 10%).
            Any other value keeps every sample.
        test_sub: Unused here; accepted so both dataset classes share the same
            constructor signature.
        split_seed: Seed of the private RNG that produces the shuffle.

    Raises:
        ValueError: If the number of samples and labels differ.
    """

    def __init__(self, path, mode, test_sub, split_seed=42):
        self.mode = mode

        eeg_parts = []
        label_parts = []
        for sub in range(config['subjects_num']):
            eeg_parts.append(np.load(path + f'sub_{sub}_eeg.npy'))
            # Labels may be stored as (N,) or (N, 1); flatten to one value per sample.
            label_parts.append(np.asarray(np.load(path + f'sub_{sub}_labels.npy')).reshape(-1))

        data = np.concatenate(eeg_parts, axis=0)
        label = np.concatenate(label_parts, axis=0)
        if len(data) != len(label):
            raise ValueError(
                f"sample/label count mismatch: {len(data)} samples vs {len(label)} labels"
            )

        # Private, seeded RNG: identical permutation for every mode, and the
        # global NumPy RNG state is left untouched.
        order = np.random.RandomState(split_seed).permutation(len(data))
        data = data[order]
        label = label[order]

        total = len(data)
        train_end = int(total * 0.8)
        val_end = int(total * 0.9)
        if mode == 'train':
            chosen = slice(0, train_end)
        elif mode == 'val':
            chosen = slice(train_end, val_end)
        elif mode == 'test':
            chosen = slice(val_end, None)
        else:
            chosen = slice(None)  # unrecognised mode: keep all samples

        self.data = torch.FloatTensor(data[chosen])
        self.label = torch.LongTensor(label[chosen])

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(eeg, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]
        
class EEG_InterSub_Dataset(Dataset):
    """Leave-one-subject-out EEG dataset.

    'train' and 'val' pool every subject except ``test_sub``, shuffle the pool
    with a permutation derived from ``split_seed`` and take the first 80% /
    last 20% respectively. Because the permutation depends only on
    ``split_seed`` and the pool size, the two splits never share a sample.
    (Drawing the permutation from the global NumPy RNG gave each instance a
    different shuffle, so validation samples leaked into training.)
    'test' loads the held-out subject unshuffled.

    Args:
        path: Prefix prepended to ``sub_{i}_eeg.npy`` / ``sub_{i}_labels.npy``;
            it is concatenated as-is, so it must end with a path separator.
        mode: One of 'train', 'val' or 'test'.
        test_sub: Index of the held-out subject.
        split_seed: Seed of the private RNG that produces the train/val shuffle.

    Raises:
        ValueError: If ``mode`` is not recognised, if ``test_sub`` is not one
            of the ``config['subjects_num']`` subject indices (train/val only),
            or if the number of samples and labels differ.
    """

    def __init__(self, path, mode, test_sub, split_seed=42):
        self.mode = mode
        self.test_sub = test_sub

        if mode in ('train', 'val'):
            train_sub = list(range(config['subjects_num']))
            train_sub.remove(test_sub)  # raises ValueError for an unknown subject

            eeg_parts = []
            label_parts = []
            for sub in train_sub:
                eeg_parts.append(np.load(path + f'sub_{sub}_eeg.npy'))
                # Labels may be stored as (N,) or (N, 1); flatten to one value per sample.
                label_parts.append(np.asarray(np.load(path + f'sub_{sub}_labels.npy')).reshape(-1))
            data = np.concatenate(eeg_parts, axis=0)
            label = np.concatenate(label_parts, axis=0)
            if len(data) != len(label):
                raise ValueError(
                    f"sample/label count mismatch: {len(data)} samples vs {len(label)} labels"
                )

            # Private, seeded RNG: the 'train' and 'val' instances see the same
            # permutation, and the global NumPy RNG state is left untouched.
            order = np.random.RandomState(split_seed).permutation(len(data))
            data = data[order]
            label = label[order]

            cut = int(len(data) * 0.8)
            if mode == 'train':
                data, label = data[:cut], label[:cut]
            else:
                data, label = data[cut:], label[cut:]

        elif mode == 'test':
            data = np.load(path + f'sub_{test_sub}_eeg.npy')
            # Flatten so test labels have the same 1-D layout as train/val labels.
            label = np.asarray(np.load(path + f'sub_{test_sub}_labels.npy')).reshape(-1)
            if len(data) != len(label):
                raise ValueError(
                    f"sample/label count mismatch: {len(data)} samples vs {len(label)} labels"
                )

        else:
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

        self.data = torch.FloatTensor(data)
        self.label = torch.LongTensor(label)

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(eeg, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]


def prep_dataloader(path, mode, batch_size, test_sub, isIntraSub = False, njobs=1):
    """Build the dataset for ``mode`` and wrap it in a ``DataLoader``.

    Args:
        path: Forwarded to the dataset constructor.
        mode: Forwarded to the dataset constructor; batches are shuffled only
            when it equals 'train'.
        batch_size: Number of samples per batch.
        test_sub: Forwarded to the dataset constructor.
        isIntraSub: Truthy selects ``EEG_IntraSub_Dataset``, otherwise
            ``EEG_InterSub_Dataset`` is used.
        njobs: Number of loader worker processes.

    Returns:
        A ``DataLoader`` with pinned memory that keeps the final partial batch.
    """
    if isIntraSub:
        tag, dataset_cls = "IntraSub", EEG_IntraSub_Dataset
    else:
        tag, dataset_cls = "InterSub", EEG_InterSub_Dataset
    print(tag)

    return DataLoader(
        dataset_cls(path, mode, test_sub),
        batch_size,
        shuffle=(mode == 'train'),
        drop_last=False,
        num_workers=njobs,
        pin_memory=True,
    )


In [None]:

class GRUNet(pl.LightningModule):
    """Bidirectional GRU classifier packaged as a Lightning module.

    A stacked bidirectional GRU is followed by two fully connected layers that
    emit ``config['num_class']`` logits.

    Args:
        input_size: Number of features per time step fed to the GRU.
        hidden_size: Hidden units per GRU direction.
        num_layers: Number of stacked GRU layers.
        dropout_prob: Dropout probability used between GRU layers and around
            the fully connected head.
    """

    def __init__(self, input_size=17, hidden_size=64, num_layers=2, dropout_prob=0.3):
        super(GRUNet, self).__init__()

        self.gru1 = nn.GRU(input_size=input_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          batch_first=True,
                          bidirectional=True,
                          dropout=dropout_prob)

        self.dropout = nn.Dropout(dropout_prob)

        # The GRU is bidirectional, so its output feature size is hidden_size*2.
        self.fc1 = nn.Linear(hidden_size*2, 64)
        self.fc2 = nn.Linear(64, config['num_class'])

        # Targets / predictions gathered over one test epoch so the metrics are
        # computed once on the whole test set.
        self._test_targets = []
        self._test_preds = []

        # Initialise weights.
        self._init_weights()

    def _init_weights(self):
        """Apply per-layer-type initialisation to every sub-module.

        Conv1d gets He (Kaiming) init, Linear gets Xavier init with zero bias,
        BatchNorm1d gets weight 1 / bias 0. Of these types, this class itself
        only creates Linear layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits for a batch.

        Args:
            x: Input laid out as (batch_size, channels, sequence_length).

        Returns:
            Tensor of shape (batch_size, config['num_class']).
        """
        # The GRU expects (batch_size, sequence_length, input_size).
        x = x.permute(0, 2, 1)

        out, _ = self.gru1(x)

        # Keep only the last time step.
        # NOTE(review): for the reverse direction, the output at the last time
        # step has processed only that one step; consider using the final
        # hidden states of both directions instead - confirm intent.
        out = out[:, -1, :]

        # Fully connected head.
        out = self.dropout(out)
        out = F.relu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)

        return out

    def configure_optimizers(self):
        """Return the Adam optimizer (lr=0.001) used for training."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch, batch_idx=None):
        """Compute and log the cross-entropy loss of one training batch."""
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        self.log('training_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        loss = {'loss': loss}
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Log validation loss and accuracy, both aggregated per epoch."""
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)

        # 'val_acc' is referenced by the checkpoint filename template in
        # config['save_name'], so it has to be logged here.
        acc = (torch.argmax(preds, dim=1) == y).float().mean()
        self.log('val_acc', acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)

    def on_test_epoch_start(self):
        """Clear the buffers that collect test predictions."""
        self._test_targets = []
        self._test_preds = []

    def test_step(self, batch, batch_idx=None):
        """Store the targets and predicted classes of one test batch."""
        x, y = batch
        preds = self(x)

        # argmax of the raw logits equals argmax of their log-softmax.
        y_pre = torch.argmax(preds, dim=1)
        self._test_targets.append(y.detach().reshape(-1).cpu())
        self._test_preds.append(y_pre.detach().cpu())

    def on_test_epoch_end(self):
        """Compute and log test metrics over the entire test set.

        Precision, recall and F1 are not additive across batches, so they are
        evaluated once on all collected predictions instead of being averaged
        from per-batch values.
        """
        if not self._test_preds:
            return

        y_true = torch.cat(self._test_targets).numpy()
        y_pred = torch.cat(self._test_preds).numpy()

        acc = accuracy_score(y_true, y_pred)
        pre = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')

        self.log('test_acc', float(acc))
        self.log('test_pre', float(pre))
        self.log('test_recall', float(recall))
        self.log('test_f1', float(f1))

        # Release the buffers once the metrics are logged.
        self._test_targets = []
        self._test_preds = []


def predict(model, dataloader):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Predictions from all batches are collected first and the metrics are then
    computed once on the full set, so the result describes the whole loader
    rather than only its final batch.

    Args:
        model: A ``torch.nn.Module`` mapping an input batch to class logits.
        dataloader: Iterable yielding ``(x, y)`` batches.

    Returns:
        Tuple ``(accuracy, weighted precision, weighted recall, weighted F1)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    # Run inputs on the device holding the model's parameters, when it has any.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    targets = []
    predictions = []
    with torch.no_grad():
        for x, y in dataloader:
            if device is not None:
                x = x.to(device)
            logits = model(x)
            # argmax of the raw logits equals argmax of their log-softmax.
            predictions.append(torch.argmax(logits, dim=1).cpu())
            targets.append(y.reshape(-1).cpu())

    if not predictions:
        raise ValueError("predict() received a dataloader that yielded no batches")

    y_true = torch.cat(targets).numpy()
    y_pred = torch.cat(predictions).numpy()

    acc = accuracy_score(y_true, y_pred)
    pre = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')

    return acc, pre, recall, f1
       
# Checkpointing policy: keep the single checkpoint with the lowest 'val_loss'
# and additionally always save the most recent one.
checkpoint_callback = ModelCheckpoint(
    filename=config['save_name'],
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True,
)

if __name__ == '__main__':
    # Leave-one-subject-out driver: train one model per held-out subject and
    # report the test metrics averaged over all subjects.
    tr_path = val_path = test_path =  "../Dataset/SEED-VIG-Subset/"
    device = get_device()
    use_gpu = device == 'cuda'
    isIntraSub = False

    # Report model complexity once on a throw-away instance.
    # `dummy_input` avoids shadowing the builtin `input`.
    profile_model = GRUNet()
    dummy_input = torch.randn(1, 17, 384)
    flops, params = profile(profile_model, inputs=(dummy_input,))
    flops, params = clever_format([flops, params], "%.3f")
    print("\033[42m" + f"FLOPs: {flops}, Parameters: {params}" + "\033[0m")

    AC,PR,RE,F1 = 0,0,0,0
    for test_sub in range(config['subjects_num']):
        tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], test_sub, isIntraSub, njobs=6)
        val_set = prep_dataloader(val_path, 'val', config['batch_size'], test_sub, isIntraSub, njobs=6)
        test_set = prep_dataloader(test_path, 'test', config['batch_size'], test_sub, isIntraSub, njobs=1)
        model =  GRUNet().to(device)
        logger = TensorBoardLogger(config['log_path1'])

        # A fresh checkpoint callback per fold: a shared instance would carry
        # its best score / best path over from the previous subject.
        fold_checkpoint = ModelCheckpoint(
            monitor='val_loss',
            filename=config['save_name'],
            save_top_k=1,
            mode='min',
            save_last=True
        )

        # Follow the detected device instead of always requesting a GPU.
        # EarlyStopping (already imported) can be appended to `callbacks`.
        trainer = Trainer(val_check_interval=1.0, max_epochs=config['n_epochs'],
                        accelerator='gpu' if use_gpu else 'cpu',
                        devices=[0] if use_gpu else 1,
                        logger=logger,
                        callbacks=[fold_checkpoint]
                        )

        trainer.fit(model, train_dataloaders=tr_set, val_dataloaders=val_set)
        # Evaluates the weights at the end of training (not the best checkpoint).
        test_results = trainer.test(model, dataloaders=test_set)

        AC += test_results[0]['test_acc']
        PR += test_results[0]['test_pre']
        RE += test_results[0]['test_recall']
        F1 += test_results[0]['test_f1']

    # Average the per-subject metrics and print them as percentages.
    AC /= config['subjects_num']
    PR /= config['subjects_num']
    RE /= config['subjects_num']
    F1 /= config['subjects_num']
    print(f"&{AC*100:.2f}",f"&{PR*100:.2f}",f"&{RE*100:.2f}",f"&{F1*100:.2f}")