In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
from tqdm import tqdm
from thop import profile, clever_format


# Favor reproducibility over raw speed: request deterministic cuDNN kernels and
# keep the cuDNN auto-tuner off. Benchmark mode selects convolution algorithms
# by timing them, so enabling it works against the deterministic setting.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), then switches cuDNN
    to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)  # Python's built-in generator
    np.random.seed(seed)  # NumPy's global generator
    torch.manual_seed(seed)  # PyTorch generator
    torch.cuda.manual_seed(seed)  # current GPU
    torch.cuda.manual_seed_all(seed)  # every GPU
    torch.backends.cudnn.deterministic = True  # deterministic convolution kernels
    torch.backends.cudnn.benchmark = False  # no timing-based algorithm selection


set_seed(42)  # fix the random seed once, before any data or model is created


# Experiment-wide settings read by the datasets, the model and the training loop.
config = dict(
    subjects_num=12,  # subjects are stored as sub_0 ... sub_{subjects_num - 1}
    n_epochs=100,  # upper bound on training epochs per run
    batch_size=64,
    save_name='logs/EEGNet8.2-{epoch:02d}-{val_acc:.2f}',  # checkpoint filename template
    log_path1='logs/EEGNet8.2_logs',  # TensorBoard log directory
    num_class=2,  # number of outputs of the classifier layer
)

# Protocol switch handed to prep_dataloader: True selects the intra-subject
# dataset, False the inter-subject one.
isIntraSub = False  # edit to change the protocol


def get_device():
    """Return ``'cuda'`` when PyTorch reports a usable CUDA device, else ``'cpu'``."""
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'



class EEG_IntraSub_Dataset(Dataset):
    """Pooled-subject EEG dataset with a 70% / 20% / 10% train/val/test split.

    For every subject ``i`` the files ``sub_{i}_eeg.npy`` and
    ``sub_{i}_labels.npy`` are loaded from ``path``. All trials are
    concatenated, put in one shuffled order, and the slice of that order
    belonging to ``mode`` is kept.

    The shuffled order comes from a private generator seeded with ``seed``.
    The train, val and test datasets are built as three separate instances, so
    they must all see the *same* permutation for their slices to be disjoint;
    drawing from the global NumPy generator would hand each instance a
    different order and leak trials between splits.

    Args:
        path: Directory containing the ``.npy`` files.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        test_sub: Unused here; accepted so both dataset classes can be built
            with the same arguments.
        seed: Seed of the split permutation. Use the same value for all modes.
        subjects_num: Number of subjects to load. Defaults to
            ``config['subjects_num']``.

    Raises:
        ValueError: If ``mode`` is not recognised or the number of trials and
            labels differ.
    """

    # (start, stop) fractions of the shuffled trials kept by each mode.
    _SPLITS = {'train': (0.0, 0.7), 'val': (0.7, 0.9), 'test': (0.9, 1.0)}

    def __init__(self, path, mode, test_sub, seed=42, subjects_num=None):
        if mode not in self._SPLITS:
            raise ValueError(f"mode must be one of {sorted(self._SPLITS)}, got {mode!r}")
        self.mode = mode
        if subjects_num is None:
            subjects_num = config['subjects_num']

        # Load each subject once and concatenate along the trial axis.
        data_parts, label_parts = [], []
        for i in range(subjects_num):
            data_parts.append(np.load(os.path.join(path, f'sub_{i}_eeg.npy')))
            label_parts.append(np.load(os.path.join(path, f'sub_{i}_labels.npy')).reshape(-1))
        data = np.concatenate(data_parts, axis=0)
        label = np.concatenate(label_parts, axis=0)
        if len(data) != len(label):
            raise ValueError(f"found {len(data)} trials but {len(label)} labels")

        # One fixed permutation shared by every mode; data and labels are
        # indexed with the same positions so they stay aligned.
        order = np.random.RandomState(seed).permutation(len(data))
        start_frac, stop_frac = self._SPLITS[mode]
        keep = order[int(len(data) * start_frac):int(len(data) * stop_frac)]

        # Insert a singleton axis after the trial axis (the conv input channel).
        self.data = torch.as_tensor(data[keep], dtype=torch.float32).unsqueeze(1)
        self.label = torch.as_tensor(label[keep], dtype=torch.long)

    def __len__(self):
        """Number of trials in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(trial, label)`` pair at ``index``."""
        return self.data[index], self.label[index]
        
class EEG_InterSub_Dataset(Dataset):
    """Leave-one-subject-out EEG dataset.

    * ``'train'`` / ``'val'``: the trials of every subject except ``test_sub``
      are concatenated and shuffled; the first 80% form the training split and
      the remaining 20% the validation split.
    * ``'test'``: all trials of ``test_sub``, in file order.

    The shuffle uses a private generator seeded with ``seed``. The train and
    val datasets are separate instances, so both must see the same permutation
    for their slices to be disjoint; drawing from the global NumPy generator
    would give each instance a different order and leak validation trials into
    the training split.

    Args:
        path: Directory containing ``sub_{i}_eeg.npy`` / ``sub_{i}_labels.npy``.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        test_sub: Index of the held-out subject.
        seed: Seed of the train/val permutation. Use the same value for both.
        subjects_num: Total number of subjects. Defaults to
            ``config['subjects_num']``.

    Raises:
        ValueError: If ``mode`` is not recognised, ``test_sub`` is not a valid
            subject index (train/val only), or trial and label counts differ.
    """

    def __init__(self, path, mode, test_sub, seed=42, subjects_num=None):
        if mode not in ('train', 'val', 'test'):
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")
        self.mode = mode
        self.test_sub = test_sub

        if mode == 'test':
            data = np.load(os.path.join(path, f'sub_{test_sub}_eeg.npy'))
            # Flatten so test labels are 1-D, exactly like train/val labels.
            label = np.load(os.path.join(path, f'sub_{test_sub}_labels.npy')).reshape(-1)
        else:
            if subjects_num is None:
                subjects_num = config['subjects_num']
            if test_sub not in range(subjects_num):
                raise ValueError(f"test_sub {test_sub!r} is not in range({subjects_num})")

            data_parts, label_parts = [], []
            for i in range(subjects_num):
                if i == test_sub:
                    continue  # the held-out subject never enters train/val
                data_parts.append(np.load(os.path.join(path, f'sub_{i}_eeg.npy')))
                label_parts.append(np.load(os.path.join(path, f'sub_{i}_labels.npy')).reshape(-1))
            data = np.concatenate(data_parts, axis=0)
            label = np.concatenate(label_parts, axis=0)
            if len(data) != len(label):
                raise ValueError(f"found {len(data)} trials but {len(label)} labels")

            # One fixed permutation shared by the train and val instances.
            order = np.random.RandomState(seed).permutation(len(data))
            cut = int(len(data) * 0.8)
            keep = order[:cut] if mode == 'train' else order[cut:]
            data, label = data[keep], label[keep]

        if len(data) != len(label):
            raise ValueError(f"found {len(data)} trials but {len(label)} labels")

        # Insert a singleton axis after the trial axis, e.g. (N, 1, 17, 384)
        # for recordings matching the model's default channel/time sizes.
        self.data = torch.as_tensor(data, dtype=torch.float32).unsqueeze(1)
        self.label = torch.as_tensor(label, dtype=torch.long)

    def __len__(self):
        """Number of trials in this split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(trial, label)`` pair at ``index``."""
        return self.data[index], self.label[index]


def prep_dataloader(path, mode, batch_size, test_sub, isIntraSub = False, njobs=1):
    """Build the DataLoader for one split of the selected protocol.

    Args:
        path: Data location forwarded to the dataset class.
        mode: Split name forwarded to the dataset; only ``'train'`` is shuffled.
        batch_size: Batch size of the returned loader.
        test_sub: Subject index forwarded to the dataset class.
        isIntraSub: True selects ``EEG_IntraSub_Dataset``, False selects
            ``EEG_InterSub_Dataset``.
        njobs: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` over the dataset with pinned memory enabled and the
        final partial batch kept.
    """
    if isIntraSub:
        protocol_name, dataset_cls = "IntraSub", EEG_IntraSub_Dataset
    else:
        protocol_name, dataset_cls = "InterSub", EEG_InterSub_Dataset
    print(protocol_name)

    split = dataset_cls(path, mode, test_sub)
    shuffle_batches = (mode == 'train')
    return DataLoader(
        split,
        batch_size,
        shuffle=shuffle_batches,
        drop_last=False,
        num_workers=njobs,
        pin_memory=True,
    )




In [3]:
# Based on EEGNet-8,2 https://github.com/amrzhd/EEGNet

class EEGNetModel(pl.LightningModule): # EEGNET-8,2
    """EEGNet-8,2 style convolutional classifier as a LightningModule.

    Input is a batch shaped ``(batch, 1, num_channels, num_time_points)``.
    ``block1`` applies temporal convolutions, ``block2`` a convolution spanning
    all ``num_channels`` rows followed by pooling, ``block3`` a grouped plus
    pointwise convolution followed by pooling, and ``fc`` maps the flattened
    features to class logits.

    Max-norm constraints on the spatial convolution (``spatial_norm``) and the
    classifier (``classifier_norm``) are applied at construction and re-applied
    after every training batch, so they keep holding while the optimizer
    updates the weights.

    Args:
        num_channels: Height of the input (number of EEG channels).
        num_time_points: Width of the input (samples per trial).
        temporal_kernel_size: Width of the temporal kernels in ``block1``.
        num_filters1: Number of temporal filters.
        num_filters2: Number of filters produced by ``block3``.
        depth_multiplier: Spatial filters per temporal filter in ``block2``.
        pool_size1: Temporal pooling width in ``block2``.
        pool_size2: Temporal pooling width in ``block3``.
        dropout_rate: Dropout probability used in ``block2`` and ``block3``.
        spatial_norm: Max L2 norm per output filter of the spatial convolution.
        classifier_norm: Max L2 norm per output row of the classifier.
        num_classes: Number of output classes. Defaults to
            ``config['num_class']``.
    """

    def __init__(self, num_channels=17, num_time_points=384, temporal_kernel_size=32,
                 num_filters1=16, num_filters2=32, depth_multiplier=2, pool_size1=8, pool_size2=16,
                 dropout_rate=0.5, spatial_norm=1, classifier_norm=0.25, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = config['num_class']
        self.spatial_norm = spatial_norm
        self.classifier_norm = classifier_norm

        # Width left after both pooling stages, times the final filter count.
        linear_size = (num_time_points//(pool_size1*pool_size2))*num_filters2

        # Temporal filters
        self.block1 = nn.Sequential(
            nn.Conv2d(1, num_filters1, (1, temporal_kernel_size), padding='same', bias=False),
            nn.BatchNorm2d(num_filters1),
        )
        # Spatial filters
        self.block2 = nn.Sequential(
            nn.Conv2d(num_filters1, depth_multiplier * num_filters1, (num_channels, 1), groups=num_filters1, bias=False), # Depthwise Conv
            nn.BatchNorm2d(depth_multiplier * num_filters1),
            nn.ELU(),
            nn.AvgPool2d((1, pool_size1)),
            nn.Dropout(dropout_rate)
        )
        # NOTE(review): groups=num_filters2 makes the first convolution below
        # one-filter-per-channel only when num_filters2 equals
        # depth_multiplier * num_filters1, which holds for the defaults.
        self.block3 = nn.Sequential(
            nn.Conv2d(depth_multiplier * num_filters1, num_filters2, (1, 16), groups=num_filters2, bias=False, padding='same'), # Separable Conv
            nn.Conv2d(num_filters2, num_filters2, kernel_size=1, bias=False), # Pointwise Conv
            nn.BatchNorm2d(num_filters2),
            nn.ELU(),
            nn.AvgPool2d((1, pool_size2)),
            nn.Dropout(dropout_rate)
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(linear_size, num_classes)

        # Buffers for dataset-level test metrics (filled by test_step).
        self._test_preds = []
        self._test_targets = []

        # Constrain the freshly initialised weights as well.
        self._enforce_max_norm()

    def _apply_max_norm(self, layer, max_norm):
        """Clip the L2 norm of each dim-0 slice of ``layer``'s weights to ``max_norm``."""
        with torch.no_grad():
            for name, param in layer.named_parameters():
                if 'weight' in name:
                    # torch.renorm rescales only the slices whose norm exceeds the limit.
                    param.copy_(torch.renorm(param, p=2, dim=0, maxnorm=max_norm))

    def _enforce_max_norm(self):
        """Apply both max-norm constraints (spatial convolution and classifier)."""
        self._apply_max_norm(self.block2[0], self.spatial_norm)
        self._apply_max_norm(self.fc, self.classifier_norm)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Re-apply the max-norm constraints after each training batch."""
        self._enforce_max_norm()

    def forward(self, x):
        """Return class logits for a batch ``x``."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch):
        """Compute and log the cross-entropy loss of one training batch."""
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        self.log('training_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        loss = {'loss': loss}
        return loss

    def validation_step(self, batch):
        """Log epoch-level validation loss and accuracy.

        ``val_acc`` is logged so that checkpoint filename templates that
        reference it can be filled in with a real value.
        """
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        acc = (preds.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)

    def on_test_epoch_start(self):
        """Reset the prediction buffers before a test run."""
        self._test_preds = []
        self._test_targets = []

    def test_step(self, batch):
        """Store the predictions and targets of one test batch.

        Metrics are computed once over the whole test set in
        ``on_test_epoch_end``: averaging per-batch precision, recall and F1
        does not give the dataset-level values. Returns nothing.
        """
        x, y = batch
        preds = self(x)
        self._test_preds.append(preds.argmax(dim=1).detach().cpu())
        self._test_targets.append(y.detach().reshape(-1).cpu())

    def on_test_epoch_end(self):
        """Compute and log accuracy and weighted precision/recall/F1 over the test set."""
        if not self._test_targets:
            return
        y_true = torch.cat(self._test_targets).numpy()
        y_pred = torch.cat(self._test_preds).numpy()

        self.log('test_acc', float(accuracy_score(y_true, y_pred)))
        self.log('test_pre', float(precision_score(y_true, y_pred, average='weighted')))
        self.log('test_recall', float(recall_score(y_true, y_pred, average='weighted')))
        self.log('test_f1', float(f1_score(y_true, y_pred, average='weighted')))

        # Release the buffers so repeated test runs start clean.
        self._test_preds = []
        self._test_targets = []


def predict(model, dataloader):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Predictions and targets are collected across all batches and the metrics
    are computed once over the complete set, so the result describes the whole
    loader rather than only its final batch.

    Args:
        model: A ``torch.nn.Module`` returning class logits ``(batch, classes)``.
            It is switched to eval mode and left there.
        dataloader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and F1
        use ``average='weighted'``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    # Run inputs on the model's device; a parameter-free model gives no hint,
    # in which case inputs are used where they are.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = None

    predictions, targets = [], []
    with torch.no_grad():
        for x, y in dataloader:
            if device is not None:
                x = x.to(device)
            logits = model(x)
            predictions.append(torch.argmax(logits, dim=1).cpu())
            targets.append(y.reshape(-1).cpu())

    if not targets:
        raise ValueError("predict() received a dataloader with no batches")

    y_true = torch.cat(targets).numpy()
    y_pred = torch.cat(predictions).numpy()
    acc = accuracy_score(y_true, y_pred)
    pre = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')

    return acc, pre, recall, f1
       
# Checkpoint policy: keep the single checkpoint with the lowest validation
# loss, plus a copy of the most recent one.
checkpoint_callback = ModelCheckpoint(
    filename=config['save_name'],
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True,
)

In [None]:
if __name__ == '__main__':
    # Dataset directory. Set SEED_VIG_DIR to run on another machine; the
    # original location remains the fallback. The join with '' guarantees a
    # trailing separator.
    data_dir = os.path.join(
        os.environ.get('SEED_VIG_DIR', "/home/jie/Program/872/Dataset/SEED-VIG-Subset/"), '')
    tr_path = val_path = test_path = data_dir
    device = get_device()
    use_gpu = (device == 'cuda')

    # One-off complexity report on a throwaway model and a single dummy trial.
    model = EEGNetModel()
    dummy_input = torch.randn(1, 1, 17, 384)
    flops, params = profile(model, inputs=(dummy_input,))
    flops, params = clever_format([flops, params], "%.3f")
    print("\033[42m" + f"FLOPs: {flops}, Parameters: {params}" + "\033[0m")

    # Train and test once per held-out subject, accumulating the test metrics.
    AC,PR,RE,F1 = 0,0,0,0
    for test_sub in range(config['subjects_num']):
        tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], test_sub, isIntraSub, njobs=6)
        val_set = prep_dataloader(val_path, 'val', config['batch_size'], test_sub, isIntraSub, njobs=6)
        test_set = prep_dataloader(test_path, 'test', config['batch_size'], test_sub, isIntraSub, njobs=1)
        model =  EEGNetModel().to(device)
        logger = TensorBoardLogger(config['log_path1'])

        # A fresh checkpoint callback per run: a ModelCheckpoint instance
        # remembers its best score, so sharing one across runs would compare a
        # new run's validation loss against a previous run's best.
        fold_checkpoint = ModelCheckpoint(
            monitor='val_loss',
            filename=config['save_name'],
            save_top_k=1,
            mode='min',
            save_last=True
        )
        trainer = Trainer(val_check_interval=1.0, max_epochs=config['n_epochs'],
                        # Follow get_device() so the script also runs without CUDA.
                        accelerator='gpu' if use_gpu else 'cpu',
                        devices=[0] if use_gpu else 1,
                        logger=logger,
                        callbacks=[
                            # Optional: EarlyStopping(monitor='val_loss', mode='min', check_on_train_epoch_end=True, patience=10, min_delta=1e-4),
                            fold_checkpoint
                        ]
                        )

        trainer.fit(model, train_dataloaders=tr_set, val_dataloaders=val_set)
        # NOTE: this tests the weights left in memory after the last epoch;
        # pass ckpt_path='best' to test the checkpoint selected by val_loss.
        test_results = trainer.test(model, dataloaders=test_set)

        AC += test_results[0]['test_acc']
        PR += test_results[0]['test_pre']
        RE += test_results[0]['test_recall']
        F1 += test_results[0]['test_f1']

    # Average over runs and print as LaTeX table cells (percentages).
    AC /= config['subjects_num']
    PR /= config['subjects_num']
    RE /= config['subjects_num']
    F1 /= config['subjects_num']
    print(f"&{AC*100:.2f}",f"&{PR*100:.2f}",f"&{RE*100:.2f}",f"&{F1*100:.2f}")