In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import csv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
from tqdm import tqdm
from thop import profile, clever_format
# 对于疲劳监测,Recall更为重要,因为我们更关注是否能检测出所有疲劳状态
# 标签0表示清醒状态,标签1表示疲劳状态
# 高Recall意味着能捕获到更多的疲劳状态,降低漏报率,这对安全性要求高的场景更重要
# Precision则关注预测为疲劳时的准确性,但可能会漏掉一些疲劳样本


torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

def set_seed(seed=42):
    np.random.seed(seed)  # 设置 NumPy 的随机种子
    torch.manual_seed(seed)  # 设置 PyTorch 的随机种子
    torch.cuda.manual_seed(seed)  # 为当前 GPU 设置随机种子
    torch.cuda.manual_seed_all(seed)  # 为所有 GPU 设置随机种子
    torch.backends.cudnn.deterministic = True  # 确保卷积等操作是确定性的
    torch.backends.cudnn.benchmark = False  # 关闭非确定性优化

set_seed(42)  # 调用函数，设置固定的随机种子


config = {
    'subjects_num': 12,
    'n_epochs': 30, 
    'batch_size': 64,
    'save_name': 'logs/FastAlertNet-{epoch:02d}-{val_acc:.2f}',
    'log_path1': 'logs/FastAlertNet_logs',  # 修改
    'num_class': 2 # 修改,二分类:0-清醒,1-疲劳
}

isIntraSub = False  # 修改


def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'



class EEG_IntraSub_Dataset(Dataset):
    def __init__(self, path, mode, test_sub):
        self.mode = mode
        sub_list = [i for i in range(config['subjects_num'])]
        data = []
        label = []

        
        for i in sub_list:
            data_sub = np.load(path + f'sub_{i}_eeg.npy')
            label_sub = np.load(path + f'sub_{i}_labels.npy')
            data.extend(data_sub)
            label.extend(label_sub)
            
        data = np.array(data)
        label = np.array(label).flatten()
        
        # 生成随机索引进行同步shuffle
        shuffle_idx = np.random.permutation(len(data))
        data = data[shuffle_idx]
        label = label[shuffle_idx]
    
        if mode == 'train':
            data = data[:int(len(data)*0.8)]
            label = label[:int(len(label)*0.8)]
       
        elif mode == 'val':
            data = data[int(len(data)*0.8):int(len(data)*0.9)]
            label = label[int(len(label)*0.8):int(len(label)*0.9)]
        
        elif mode == 'test':
            data = data[int(len(data)*0.9):]
            label = label[int(len(label)*0.9):]
        
        self.data = torch.FloatTensor(data)
        self.label = torch.LongTensor(label)

    def __len__(self):
        return len(self.data)  

    def __getitem__(self, index):
        return self.data[index], self.label[index]
        
class EEG_InterSub_Dataset(Dataset):
    def __init__(self, path, mode, test_sub):
        self.mode = mode
        self.test_sub = test_sub
        
        if mode == 'train' or mode == 'val':
            train_sub = [i for i in range(config['subjects_num'])]
            train_sub.remove(test_sub)
            data = []
            label = []
            for i in train_sub:
                data_sub = np.load(path + f'sub_{i}_eeg.npy')
                label_sub = np.load(path + f'sub_{i}_labels.npy')
                data.extend(data_sub)
                label.extend(label_sub)
                
            data = np.array(data)
            label = np.array(label).flatten()
            # 生成随机索引进行同步shuffle
            shuffle_idx = np.random.permutation(len(data))
            data = data[shuffle_idx]
            label = label[shuffle_idx]
    
            if mode == 'train':
                data = data[:int(len(data)*0.9)]
                label = label[:int(len(label)*0.9)]
                
            elif mode == 'val':
                data = data[int(len(data)*0.9):]
                label = label[int(len(label)*0.9):]
                   
        
        elif mode == 'test':
            
            data = np.load(path + f'sub_{test_sub}_eeg.npy')
            label = np.load(path + f'sub_{test_sub}_labels.npy')

        
        self.data = torch.FloatTensor(data)
        self.label = torch.LongTensor(label)      
    def __len__(self):
        return len(self.data)  # 返回数据的总个数

    def __getitem__(self, index):
        return self.data[index], self.label[index]


def prep_dataloader(path, mode, batch_size, test_sub, isIntraSub = False, njobs=1):
    if isIntraSub:
        print("IntraSub")
        dataset = EEG_IntraSub_Dataset(path, mode, test_sub)
    else:
        print("InterSub")
        dataset = EEG_InterSub_Dataset(path, mode, test_sub)
        
    dataloader = DataLoader(dataset, batch_size, shuffle=(mode == 'train'), drop_last=False, num_workers=njobs,
                            pin_memory=True)
    return dataloader




In [4]:
# InterpretCNN Detail
class InterpretCNN(pl.LightningModule):
    
    def __init__(self, input_channels=17, input_length=384, num_classes=2):

        super(InterpretCNN, self).__init__()
        
        # Define activation function
        
        self.pointwise_conv = nn.Conv1d(
            in_channels=input_channels, 
            out_channels=8, 
            kernel_size=1, 
            stride=1, 
            padding=0
        )
        #self.bn1 = nn.BatchNorm1d(16)
        #self.relu1 = nn.ReLU()
        
        # Depthwise Convolution
        self.depthwise_conv = nn.Conv1d(
            in_channels=8, 
            out_channels=16, 
            kernel_size=3, 
            stride=1, 
            padding=1, 
            groups=8  # Depthwise Convolution
        )
        self.bn2 = nn.BatchNorm1d(16)
        self.relu2 = nn.ReLU()
        
        # Global Average Pooling
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        
        # Dense Layer
        self.fc = nn.Sequential(
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, num_classes)
        )
        
    def forward(self, x):
        # Input: [batch_size, 30, 384]
        x = self.pointwise_conv(x)  # [batch_size, 16, 384]
        #x = self.bn1(x)
        #x = self.relu1(x)
        
        x = self.depthwise_conv(x)  # [batch_size, 32, 384]
        x = self.bn2(x)
        x = self.relu2(x)
        
        x = self.global_avg_pool(x)  # [batch_size, 32, 1]
        x = x.squeeze(-1)  # [batch_size, 32]
        
        x = self.fc(x)  # [batch_size, num_classes]
        return x


    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
      
    def training_step(self, batch):
        x, y = batch
        preds = self(x)
        #loss = self.arc_loss(preds, y)
        loss = F.cross_entropy(preds, y)
        self.log('training_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        loss = {'loss': loss}
        return loss

    def validation_step(self, batch):
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        #self.arc_loss(preds, y)
        #entropy = F.cross_entropy(preds, y)
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)

    def test_step(self, batch):
        x, y = batch
        preds = self(x)
        
        y_pre = torch.argmax(F.log_softmax(preds, dim=1), dim=1)
        acc = accuracy_score(y.cpu(), y_pre.cpu())
        pre = precision_score(y.cpu(), y_pre.cpu(), average='weighted')
        recall = recall_score(y.cpu(), y_pre.cpu(), average='weighted')
        f1 = f1_score(y.cpu(), y_pre.cpu(), average='weighted')

        self.log('test_acc', acc)
        self.log('test_pre', pre)
        self.log('test_recall', recall)
        self.log('test_f1', f1)
        
        return {'test_acc': acc, 'test_pre': pre, 'test_recall': recall, 'test_f1': f1} 
        #return {'acc': acc, 'pre': pre, 'recall': recall, 'f1': f1}

def predict(model, dataloader):
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x, y = batch
            preds = model(x)
            print(preds)
            y_pre = torch.argmax(F.log_softmax(preds, dim=1), dim=1)
            acc = accuracy_score(y.cpu(), y_pre.cpu())
            pre = precision_score(y.cpu(), y_pre.cpu(), average='weighted')
            recall = recall_score(y.cpu(), y_pre.cpu(), average='weighted')
            f1 = f1_score(y.cpu(), y_pre.cpu(), average='weighted')

    return acc, pre, recall, f1
       
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename=config['save_name'],
    save_top_k=1,
    mode='min',
    save_last=True
)

if __name__ == '__main__':
    tr_path = val_path = test_path =  "/home/jie/Program/872/Dataset/SEED-VIG-Subset/"

    device = get_device()
    model = InterpretCNN()
    input = torch.randn(1, 17, 384)
    flops, params = profile(model, inputs=(input,))
    flops, params = clever_format([flops, params], "%.3f")
    print("\033[42m" + f"FLOPs: {flops}, Parameters: {params}" + "\033[0m")
    
    for i in range(12):
        test_sub = i
        tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], test_sub, isIntraSub=True, njobs=6)
        val_set = prep_dataloader(val_path, 'val', config['batch_size'], test_sub, isIntraSub=True, njobs=6)
        test_set = prep_dataloader(test_path, 'test', config['batch_size'], test_sub, isIntraSub=True, njobs=1)
        model =  InterpretCNN().to(device)
        logger = TensorBoardLogger(config['log_path1'])#, config['log_path2'])
        trainer = Trainer(val_check_interval=1.0, max_epochs=config['n_epochs'], devices=[0], accelerator='gpu',
                        logger=logger,
                        callbacks=[
                            #EarlyStopping(monitor='val_loss', mode='min', check_on_train_epoch_end=True, patience=10, min_delta=1e-4),
                            checkpoint_callback
                        ]
                        )
        
        trainer.fit(model, train_dataloaders=tr_set, val_dataloaders=val_set)
        # 保存最终模型
        trainer.save_checkpoint('InterpretCNN.ckpt')

        test_results = trainer.test(model, dataloaders=test_set)
        # 将测试结果写入文件
        f = open('InterpretCNN_test_results.txt', 'a')
        f.write('Subject:'+str(test_sub))
        for metrics in test_results:
            for metric_name, value in metrics.items():
                f.write(','+str(value))
        f.write('\n')
        f.close()
    
    #test_model = FastAlertNet.load_from_checkpoint("/home/jie/Program/872/FastAlertNet.ckpt", map_location=torch.device('cuda')).float().to('cuda')


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[42mFLOPs: 101.680K, Parameters: 546.000B[0m
IntraSub
IntraSub
IntraSub


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type              | Params | Mode 
--------------------------------------------------------------
0 | pointwise_conv  | Conv1d            | 144    | train
1 | depthwise_conv  | Conv1d            | 64     | train
2 | bn2             | BatchNorm1d       | 32     | train
3 | relu2           | ReLU              | 0      | train
4 | global_avg_pool | AdaptiveAvgPool1d | 0      | train
5 | fc              | Sequential        | 306    | train
--------------------------------------------------------------
546       Trainable params
0         Non-trainable params
546       Total params
0.002     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _releaseLock at 0x7b74ad056a20>
Traceback (most recent call last):
  File "/home/jie/anaconda3/lib/python3.12/logging/__init__.py", line 243, in _releaseLock
    def _releaseLock():
    
KeyboardInterrupt: 


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/jie/anaconda3/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=35` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

IntraSub
IntraSub
IntraSub


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
/home/jie/anaconda3/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory logs/FastAlertNet_logs/lightning_logs/version_9/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type              | Params | Mode 
--------------------------------------------------------------
0 | pointwise_conv  | Conv1d            | 144    | train
1 | depthwise_conv  | Conv1d            | 64     | train
2 | bn2             | BatchNorm1d       | 32     | train
3 | relu2           | ReLU              | 0      | train
4 | global_avg_pool | AdaptiveAvgPool1d | 0      | train
5 | fc              | Sequential        | 306    | train
--------------------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [3]:
f = open('metric.txt', 'r')
avg_acc = 0
avg_pre = 0
avg_recall = 0
avg_f1 = 0
for i in f.readlines():
    acc,pre,recall,f1 = i.split(',')[1:]
    avg_acc += float(acc)
    avg_pre += float(pre)
    avg_recall += float(recall)
    avg_f1 += float(f1)
    
avg_acc /= config['subjects_num']
avg_pre /= config['subjects_num']
avg_recall /= config['subjects_num']
avg_f1 /= config['subjects_num']
print(f"&{avg_acc*100:.2f}",f"&{avg_pre*100:.2f}",f"&{avg_recall*100:.2f}",f"&{avg_f1*100:.2f}")

&74.95 &97.02 &74.95 &80.05
