In [None]:
import os
from os.path import join as pjoin
import h5py
import numpy as np
import scipy.io as sio
from tqdm import tqdm
import mne
from mne.io import RawArray

import argparse
import json
import logging
import sys
from os import makedirs
from os.path import join as pjoin
from shutil import copy2, move

import h5py
import numpy as np
import torch
import torch.nn.functional as F
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from sklearn.model_selection import KFold

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


In [None]:
class EEGDataProcessor:
    def __init__(self, sampling_rate=250):
        self.sfreq = sampling_rate
        self.ch_types = ['eeg'] * 32
        self.filter_bands = [
            (4, 8), (8, 12), (12, 16), (16, 20), 
            (20, 24), (24, 28), (28, 32), (32, 36), (36, 40)
        ]

    def create_raw_from_data(self, data, ch_names):
        info = mne.create_info(
            ch_names=ch_names,
            sfreq=self.sfreq,
            ch_types=self.ch_types[:len(ch_names)]
        )
        return mne.io.RawArray(data, info)

    def apply_filter_bank(self, raw):
        filtered_data = []
        for low_freq, high_freq in self.filter_bands:
            raw_filtered = raw.copy()
            raw_filtered.filter(low_freq, high_freq, method='iir')
            filtered_data.append(raw_filtered.get_data())
        return np.stack(filtered_data, axis=-1)

In [None]:
def get_shu_data(subj, processor):
    file_path = os.getcwd()
    data_list = []
    labels = np.empty(0)
    
    # 上海大学通道名
    ch_names = ["Fp1", "Fp2", "Fz", "F3", "F4", "F7", "F8", "FC1", "FC2", "FC5",
                "FC6", "Cz", "C3", "C4", "T3", "T4", "A1", "A2", "CP1", "CP2",
                "CP5", "CP6", "Pz", "P3", "P4", "T5", "T6", "PO3", "PO4", "Oz",
                "O1", "O2"]
    
    # 删除的通道
    channels_to_delete = [16, 17, 27, 28]
    keep_channels = np.ones(len(ch_names), dtype=bool)
    keep_channels[channels_to_delete] = False
    ch_names = [ch for i, ch in enumerate(ch_names) if keep_channels[i]]
    
    for session in range(1, 6):
        da = sio.loadmat(pjoin(file_path, 'SHU_Dataset', 
                              f'sub-{str(subj).zfill(3)}_ses-{str(session).zfill(2)}_task_motorimagery_eeg.mat'))
        print(da['data'].shape)
        # 对每个样本
        for trial_idx in range(da['data'].shape[0]):
            trial_data = da['data'][trial_idx]  # (channels, time)
            trial_data = trial_data[keep_channels]
            print(trial_data.shape)
            # 滤波
            raw = processor.create_raw_from_data(trial_data, ch_names)
            filtered_data = processor.apply_filter_bank(raw)
            data_list.append(filtered_data)
            
        labels = np.hstack((labels, np.ravel(da['labels'])))
    
    X = np.stack(data_list, axis=0)
    Y = (np.ravel(labels) - 1).astype(np.int64)
    return X, Y

In [None]:
def get_stroke_data(subj, processor):
    file_path = os.getcwd()
    
    # 中风通道名
    ch_names = ["Fp1", "Fp2", "Fz", "F3", "F4", "F7", "F8", "FCz", "FC3", "FC4", "FT7",
                "FT8", "Cz", "C3", "C4", "T3", "T4", "CPz", "CP3", "CP4", "TP7", "TP8",
                "Pz", "P3", "P4", "T5", "T6", "Oz", "O1", "O2", "HEOL", "VEOR", "Tng"]
    
    # 删除的通道
    channels_to_delete = [7, 17, 30, 31, 32]
    keep_channels = np.ones(len(ch_names), dtype=bool)
    keep_channels[channels_to_delete] = False
    ch_names = [ch for i, ch in enumerate(ch_names) if keep_channels[i]]
    
    da = sio.loadmat(pjoin(file_path, 'Stroke_Dataset', 
                          f'sub-{str(subj).zfill(2)}',
                          f'sub-{str(subj).zfill(2)}_task-motor-imagery_eeg.mat'))
    
    data = da['eeg']['rawdata'][0, 0]
    labels = np.ravel(da['eeg']['label'][0, 0])
    
    # 采样率为500Hz，只要一半250Hz
    data = data[:, :, 1000:3000:2]  # (trials, channels, time)
    data_list = []
    
    print(data.shape)
    
    for trial_idx in range(data.shape[0]):
        trial_data = data[trial_idx]  # (channels, time)
        trial_data = trial_data[keep_channels]
        print(trial_data.shape)
        # 滤波
        raw = processor.create_raw_from_data(trial_data, ch_names)
        filtered_data = processor.apply_filter_bank(raw)
        data_list.append(filtered_data)
    
    X = np.stack(data_list, axis=0)
    Y = (labels - 1).astype(np.int64)
    return X, Y


In [None]:
# 初始化滤波器
processor = EEGDataProcessor(sampling_rate=250)

In [None]:
# 处理上海大学数据
print("Processing SHU Dataset...")
with h5py.File('data/SHU_data.h5', 'w') as f:
    for subj in tqdm(range(1, 26)):
        X, Y = get_shu_data(subj, processor)
        
        f.create_dataset('s' + str(subj) + '/X', data=X)
        f.create_dataset('s' + str(subj) + '/Y', data=Y)


In [None]:
# 处理中风数据
print("\nProcessing Stroke Dataset...")
with h5py.File('data/Stroke_data.h5', 'w') as f:
    for subj in tqdm(range(1, 51)):
        X, Y = get_stroke_data(subj, processor)

        f.create_dataset('s' + str(subj) + '/X', data=X)
        f.create_dataset('s' + str(subj) + '/Y', data=Y)

In [None]:
# 标记完成
print('done')

In [None]:
# 定义FBCNet（从上海大学获取）
"""
All network architectures: FBCNet, EEGNet, DeepConvNet
@author: Ravikiran Mane
"""
import torch
import torch.nn as nn
import sys
current_module = sys.modules[__name__]

debug = False

class Conv2dWithConstraint(nn.Conv2d):
    def __init__(self, *args, doWeightNorm = True, max_norm=1, **kwargs):
        self.max_norm = max_norm
        self.doWeightNorm = doWeightNorm
        super(Conv2dWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x):
        if self.doWeightNorm:
            self.weight.data = torch.renorm(
                self.weight.data, p=2, dim=0, maxnorm=self.max_norm
            )
        return super(Conv2dWithConstraint, self).forward(x)

class LinearWithConstraint(nn.Linear):
    def __init__(self, *args, doWeightNorm = True, max_norm=1, **kwargs):
        self.max_norm = max_norm
        self.doWeightNorm = doWeightNorm
        super(LinearWithConstraint, self).__init__(*args, **kwargs)

    def forward(self, x):
        if self.doWeightNorm:
            self.weight.data = torch.renorm(
                self.weight.data, p=2, dim=0, maxnorm=self.max_norm
            )
        return super(LinearWithConstraint, self).forward(x)

class VarLayer(nn.Module):
    '''
    The variance layer: calculates the variance of the data along given 'dim'
    '''
    def __init__(self, dim):
        super(VarLayer, self).__init__()
        self.dim = dim

    def forward(self, x):
        return x.var(dim = self.dim, keepdim= True)

class FBCNet(nn.Module):
    '''
        Just a FBCSP like structure : Channel-wise convolution and then variance along the time axis
        The data input is in a form of batch x 1 x chan x time x filterBand
    '''

    def SCB(self, m, nChan, nBands, doWeightNorm=True, *args, **kwargs):
        '''
        The spatial convolution block
        m : number of spatial filters.
        nBands: number of bands in the data
        '''
        return nn.Sequential(
            Conv2dWithConstraint(nBands, m * nBands, (nChan, 1), groups=nBands,
                                 max_norm=2, doWeightNorm=doWeightNorm, padding=0),
            nn.BatchNorm2d(m * nBands),
            nn.ELU()
        )

    def LastBlock(self, inF, outF, doWeightNorm=True, *args, **kwargs):
        return nn.Sequential(
            LinearWithConstraint(inF, outF, max_norm=0.5, doWeightNorm=doWeightNorm, *args, **kwargs),
            nn.LogSoftmax(dim=1))

    def __init__(self, nChan, nTime, nClass=2, nBands=9, m=4,
                 temporalLayer='VarLayer', doWeightNorm=True, *args, **kwargs):
        super(FBCNet, self).__init__()

        self.nBands = nBands
        self.m = m

        # create all the parallel SCBc
        self.scb = self.SCB(m, nChan, self.nBands, doWeightNorm=doWeightNorm)

        # Formulate the temporal aggregator
        self.temporalLayer = current_module.__dict__[temporalLayer](dim=3)

        # The final fully connected layer
        self.lastLayer = self.LastBlock(self.m * self.nBands, nClass, doWeightNorm=doWeightNorm)

    def forward(self, x):
        # 根据我的数据处理一下形状
        # print(x.shape)
        x = x.unsqueeze(1)
        # print(x.shape)
        x = torch.squeeze(x.permute((0, 4, 2, 3, 1)), dim=4)
        
        x = self.scb(x)
        x = self.temporalLayer(x)
        x = torch.flatten(x, start_dim=1)
        x = self.lastLayer(x)
        return x


In [None]:
# 设置日志格式
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)

# 配置参数
datapath = os.path.join(os.getcwd(), 'data','SHU_data.h5')  # 数据文件路径
strokepath = os.path.join(os.getcwd(), 'data','Stroke_data.h5') # 中风患者数据文件路径
outpath = os.path.join(os.getcwd(), 'results')         # 结果输出路径
if torch.cuda.is_available():
    gpu_device = 0
    device = 'cuda'
    torch.cuda.set_device(gpu_device)
else:
    gpu_device = 'cpu'
    device = 'cpu'

# 设置设备
torch.cuda.set_device(gpu_device)

# 设置随机种子
set_random_seeds(seed=20200205, cuda=True)

# 训练参数
BATCH_SIZE = 32
TRAIN_EPOCH = 350

# 获取单个被试数据
def get_data(dfile, subj):
    dpath = '/s' + str(subj)
    X = dfile[dpath]['X']
    Y = dfile[dpath]['Y']
    return X, Y

# 获取多个被试数据
def get_multi_data(dfile, subjs):
    Xs = []
    Ys = []
    for s in subjs:
        x, y = get_data(dfile, s)
        Xs.append(x[:])
        Ys.append(y[:])
    X = np.concatenate(Xs, axis=0)
    Y = np.concatenate(Ys, axis=0)
    return X, Y

# 训练base model
def train_model(fold,lr,X_health, Y_health):
    outpath = pjoin('results', 'S'+str(fold+1).zfill(2))
    makedirs(outpath, exist_ok=True)
    
    # 中风被试编号
    all_subjs = list(range(1, 51))
    subjs = list(range(1, 51))
    
    # 确定被试和交叉验证集
    test_subj = subjs[fold]
    cv_set = np.array(all_subjs[fold+1:] + all_subjs[:fold])
    ###################################################
    # cv_set = np.array(subjs[-1:] + subjs[:1])
    

    # 创建6折交叉验证
    kf = KFold(n_splits=6)
    ###################################################
    # kf = KFold(n_splits=2)
    
    # 打开中风数据文件
    print('start base model training')
    with h5py.File(strokepath, 'r') as sfile:
        cv_loss = []
        
        # 交叉验证
        for cv_index, (train_index, test_index) in enumerate(kf.split(cv_set)):
            # 准备训练、验证和测试集
            train_subjs = cv_set[train_index]
            valid_subjs = cv_set[test_index]
            
            X_train, Y_train = get_multi_data(sfile, train_subjs)
            X_val, Y_val = get_multi_data(sfile, valid_subjs)
            X_test, Y_test = get_data(sfile, test_subj)
            
            # 扩充训练集
            X_train = np.concatenate((X_train, X_health), axis=0)
            Y_train = np.concatenate((Y_train, Y_health), axis=0)
    
            # 扩充验证集
            X_val = np.concatenate((X_val, X_health), axis=0)
            Y_val = np.concatenate((Y_val, Y_health), axis=0)
            
            # 转换为张量
            X_train = torch.tensor(X_train, dtype=torch.float32)
            Y_train = torch.tensor(Y_train, dtype=torch.long)
            X_val = torch.tensor(X_val, dtype=torch.float32)
            Y_val = torch.tensor(Y_val, dtype=torch.long)
            X_test = torch.tensor(X_test[:], dtype=torch.float32)
            Y_test = torch.tensor(Y_test[:], dtype=torch.long)
            
            # 将数据移到GPU
            X_train = X_train.to(device)
            Y_train = Y_train.to(device)
            X_val = X_val.to(device)
            Y_val = Y_val.to(device)
            X_test = X_test.to(device)
            Y_test = Y_test.to(device)
            
            # 创建DataLoader
            train_dataset = TensorDataset(X_train, Y_train)
            valid_dataset = TensorDataset(X_val, Y_val)
            test_dataset = TensorDataset(X_test, Y_test)
            
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
            valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            # print('len(test_loader)',len(test_loader))
            
            # 模型参数
            n_classes = 2
            in_chans = X_train.shape[1]
            # print(in_chans)
            
            # 创建模型
            model = FBCNet(
                nChan=in_chans,
                nClass=n_classes,
                nTime=X_train.shape[2],
                nBands=9
            ).to(device)
            
            # 训练参数
            optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.5*0.0001)
            scheduler = CosineAnnealingLR(optimizer, T_max=TRAIN_EPOCH)
            criterion = F.nll_loss
            
            # 训练
            best_val_loss = float('inf')
            best_val_acc = float('inf')
            for epoch in range(TRAIN_EPOCH):
                model.train()
                running_loss = 0.0
                for inputs, labels in train_loader:
                    optimizer.zero_grad()
                    
                    # print(inputs.shape, labels.shape)
                    
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    
                scheduler.step()
                
                # 验证过程
                model.eval()
                
                train_acc = 0.0
                with torch.no_grad():
                    for inputs, labels in train_loader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)
                        _, predicted = torch.max(outputs, 1)
                        train_acc += (predicted == labels).float().sum().item()

                        
                
                val_loss = 0.0
                val_acc = 0.0
                with torch.no_grad():
                    for inputs, labels in valid_loader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        val_loss += loss.item()
                        _, predicted = torch.max(outputs, 1)
                        val_acc += (predicted == labels).float().sum().item()

                # 保存最佳模型
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_acc = val_acc
                    torch.save(model.state_dict(), pjoin(outpath, f'best_model_f{fold}_cv{cv_index}.pt'))

                # 打印训练和验证结果
                if epoch % 20 == 0:
                    print(f'Epoch {epoch+1}/{TRAIN_EPOCH}, Training Loss: {running_loss/len(train_loader)}, Training Acc: {train_acc/len(train_loader.dataset)}, Validation Loss: {val_loss/len(valid_loader)}, Validation Acc: {val_acc/len(valid_loader.dataset)}')
            
            # 评估测试集
            model = FBCNet(
                nChan=in_chans,
                nClass=n_classes,
                nTime=X_train.shape[2],
                nBands=9
            ).to(device)
            model.load_state_dict(torch.load(pjoin(outpath, f'best_model_f{fold}_cv{cv_index}.pt')))
            model.eval()
            test_loss = 0.0
            test_acc = 0.0
            with torch.no_grad():
                for inputs, labels in test_loader:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    test_loss += loss.item()
                    _, predicted = torch.max(outputs, 1)
                    test_acc += (predicted == labels).float().sum().item()
                    
            
            with open(pjoin(outpath, f'test_base_s{test_subj}_f{fold}_cv{cv_index}.json'), 'w') as f:
                json.dump({'test_loss': test_loss/len(test_loader), 'test_acc': test_acc/len(test_loader.dataset)}, f)
            
            cv_loss.append(best_val_loss)

        # 选择最佳交叉验证模型
        best_cv = np.argmin(cv_loss)
        best_dir = pjoin(outpath, "best")
        os.makedirs(best_dir, exist_ok=True)
        
        # 记录折数
        with open(pjoin(best_dir, "fold_bestcv.txt"), 'w') as f:
            f.write(f"{fold}, {best_cv}\n")
        
        # 复制文件到best目录
        copy2(pjoin(outpath, f'best_model_f{fold}_cv{best_cv}.pt'),
              pjoin(best_dir, f'model_f{fold}.pt'))
        
        copy2(pjoin(outpath, f'test_base_s{test_subj}_f{fold}_cv{best_cv}.json'),
              pjoin(best_dir, f'test_base_s{test_subj}_f{fold}.json'))



In [None]:
health_set = np.array(list(range(1,6))) # 使用上海大学的前5个被试（否则服务器跑不动）
with h5py.File(datapath, 'r') as dfile:
        X_health, Y_health = get_multi_data(dfile, health_set)
# 设置折数（测试的被试数量）
folds = 50
###################################################
# folds = 1

lr = 1*0.001
for fold in range(folds):
    print(f"开始训练第 {fold+1} 折")
    train_model(fold,lr,X_health, Y_health)
    print(f"第 {fold+1} 折训练完成")

In [None]:
# 标记完成
print("base model training is done")

In [None]:
import os
import shutil

# 获取当前工作目录
cwd = os.getcwd()

# 创建目标文件夹路径
models_dir = os.path.join(cwd, 'models')
results_dir = os.path.join(cwd, 'results')
if not os.path.exists(models_dir):
    os.makedirs(models_dir)

# 循环处理S01到S50
for i in range(1, 51):
    folder_name = f'S{i:02}'
    source_folder = os.path.join(results_dir, folder_name, 'best')

    # 检查文件夹是否存在
    if os.path.exists(source_folder):
        # 列出所有model开头，.pt结尾的文件
        for file_name in os.listdir(source_folder):
            if file_name.startswith('model') and file_name.endswith('.pt'):
                source_file = os.path.join(source_folder, file_name)
                destination_file = os.path.join(models_dir, file_name)
                # 复制文件
                shutil.copy(source_file, destination_file)
                print(f'Copied {source_file} to {destination_file}')
    else:
        print(f'Folder {source_folder} does not exist')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import h5py
import numpy as np
from sklearn.model_selection import StratifiedKFold

datapath = os.path.join(os.getcwd(),'data','SHU_data.h5')
strokepath = os.path.join(os.getcwd(), 'data','Stroke_data.h5') # 中风患者数据文件路径
outpath = os.path.join(os.getcwd(), 'results')         # 结果输出路径
modelpath = os.path.join(os.getcwd(),'models')    # 模型路径
ADAPT_EPOCH = 50
subjs = list(range(1, 51))
################################################
# subjs = list(range(1, 2))
def adaptive_train(modelpath, subjs, scheme=4, lr=0.0005):
    results = []
    
    for fold, subj in enumerate(subjs):
        # 加载预训练检查点
        checkpoint = torch.load(
            os.path.join(modelpath, f'model_f{fold}.pt'), 
            map_location=f'cuda:{torch.cuda.current_device()}'
        )
    
        with h5py.File(strokepath, 'r') as sfile:
            # 获取数据
            X, Y = get_data(sfile, subj)
            
            # 初始化模型
            model = FBCNet(
                nChan=X.shape[1], 
                nTime=X.shape[2], 
                nClass=2, 
                nBands=9
            ).cuda()
            
            # 加载预训练权重
            model.load_state_dict(checkpoint)
            
            # 按自适应策略冻结/解冻层
            if scheme != 5:
                # 冻结所有层
                for param in model.parameters():
                    param.requires_grad = False
    
                if scheme in {1, 2, 3, 4}:
                    # 解冻分类器层
                    for param in model.lastLayer.parameters():
                        param.requires_grad = True
    
                if scheme in {2, 3, 4}:
                    # 解冻conv4层
                    for param in model.scb.parameters():
                        param.requires_grad = True
    
            # 优化器
            optimizer = torch.optim.AdamW(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=lr, weight_decay=0.5*0.001
            )
            criterion = nn.NLLLoss()
            
            # 训练和评估
            cv = StratifiedKFold(n_splits=10, shuffle=False)
            cv_acc = []
            
            for cv_fold, (train_idx, test_idx) in enumerate(cv.split(X, Y)):
                X_train, X_test = X[train_idx], X[test_idx]
                Y_train, Y_test = Y[train_idx], Y[test_idx]
                X_train_tensor = torch.FloatTensor(X_train).cuda()
                Y_train_tensor = torch.LongTensor(Y_train.astype(np.int64)).cuda()
                X_test_tensor = torch.FloatTensor(X_test).cuda()
                Y_test_tensor = torch.LongTensor(Y_test.astype(np.int64)).cuda()
                
                # 训练模型
                model.train()
                for epoch in range(ADAPT_EPOCH):
                    optimizer.zero_grad()
                    outputs = model(X_train_tensor)
                    loss = criterion(outputs, Y_train_tensor)
                    loss.backward()
                    optimizer.step()
                
                # 评估模型
                model.eval()
                with torch.no_grad():
                    test_outputs = model(X_test_tensor)
                    _, predicted = torch.max(test_outputs, 1)
                    accuracy = (predicted == Y_test_tensor).float().mean().item()
                    cv_acc.append(accuracy)
            avg_acc = np.mean(cv_acc)
            
            results.append([subj,avg_acc, cv_acc])
    
    return results

In [None]:
results = adaptive_train(modelpath, subjs, scheme=4)

In [None]:
# 标记完成
print('adapt model training is done')

In [None]:
print(results)

In [None]:
# 标记完成
print('all done')

In [2]:
sumup = 0
for i in range(len(results)):
    sumup += results[i][1]
    print(f'avg acc for subj{i+1}: ',results[i][1])
print('final acc for all subjs: ',sumup/len(results))

avg acc for subj1:  0.825
avg acc for subj2:  0.825
avg acc for subj3:  0.8
avg acc for subj4:  0.875
avg acc for subj5:  0.9
avg acc for subj6:  0.775
avg acc for subj7:  0.775
avg acc for subj8:  0.925
avg acc for subj9:  0.9
avg acc for subj10:  0.725
avg acc for subj11:  0.8
avg acc for subj12:  0.9
avg acc for subj13:  0.775
avg acc for subj14:  0.825
avg acc for subj15:  0.975
avg acc for subj16:  0.725
avg acc for subj17:  0.725
avg acc for subj18:  0.775
avg acc for subj19:  0.85
avg acc for subj20:  0.8
avg acc for subj21:  0.875
avg acc for subj22:  0.95
avg acc for subj23:  0.875
avg acc for subj24:  0.825
avg acc for subj25:  0.925
avg acc for subj26:  0.75
avg acc for subj27:  0.875
avg acc for subj28:  0.875
avg acc for subj29:  0.9
avg acc for subj30:  0.85
avg acc for subj31:  0.875
avg acc for subj32:  0.825
avg acc for subj33:  0.925
avg acc for subj34:  0.875
avg acc for subj35:  0.85
avg acc for subj36:  0.85
avg acc for subj37:  0.9
avg acc for subj38:  0.85
avg ac