In [1]:
import os

# 设置文件夹路径和类别名称
folders = [r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\HC',
           r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\MDD',
           r'C:\Users\bugs_\PycharmProjects\eegProject\data\Test_EEG\BD']
class_names = ['HC', 'MDD', 'BD']

# 创建存储数据集的字典
data_file = {'filename': []}

# 遍历每个类别的文件夹
for folder, class_name in zip(folders, class_names):
    # 获取文件夹中的Clean.mat文件列表
    file_list = os.listdir(folder)
    file_list = [os.path.join(folder, file) for file in file_list if file.endswith('Clean.mat')]

    # 将数据加入到对应的数据集中
    data_file['filename'].extend([(file, class_name) for file in file_list])


In [2]:
import numpy as np
import scipy.io as sio

label_dic = {'HC': 0, 'MDD': 1, "BD": 2}
target_shape = (284, 16, 24000)
X = np.empty(target_shape)
y = {'label': np.empty(284, dtype=int)}

num_merged = 0

for sub, label in data_file['filename']:  # 被试循环

    data = sio.loadmat(sub)
    sample = data['EEG_ECClean']
    eeg_data = sample["data"][0][0][:, 0:24000]
    X[num_merged] = eeg_data
    y['label'][num_merged] = label_dic[label]
    num_merged += 1

In [3]:
from torcheeg import transforms
from torcheeg.datasets import NumpyDataset

dataset = NumpyDataset(X=X,
                       y=y,
                       online_transform=transforms.Compose([
                           transforms.MinMaxNormalize(axis=-1),
                           transforms.ToTensor()
                       ]),
                       label_transform=transforms.Select('label'),
                       num_worker=4
                       )
print(dataset[0])

The target folder already exists, if you need to regenerate the database IO, please delete the path ./io/numpy.
(tensor([[-1.4405, -0.4301,  1.5837,  ..., -5.1220, -4.4198, -3.1613],
        [-2.5493, -3.8331, -3.7955,  ..., -3.7847, -3.4779, -3.4689],
        [-0.9064, -0.0218,  0.6677,  ..., -0.7242, -1.5270, -1.8247],
        ...,
        [-0.6049, -0.0560,  0.7506,  ...,  1.4891,  1.9511,  1.6877],
        [ 0.4088, -1.1082, -1.8689,  ..., -0.8147, -1.5469, -1.1633],
        [ 2.4367,  3.2410,  2.2675,  ...,  5.2765,  5.4156,  4.2853]]), 0)


In [22]:
dataset[0][1]

0

In [4]:
from torcheeg.model_selection import KFold

k_fold = KFold(n_splits=5,
               split_path=f'./tmp_out/split_5',
               shuffle=True)

In [14]:
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)
batch_size = 8

In [15]:
from torch.utils.tensorboard import writer

def train(dataloader, model, loss_fn, optimizer):
    global total_train_step
    size = len(dataloader.dataset)
    model.train()
    train_correct = 0

    for batch_idx, batch in enumerate(dataloader):
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1

        # if batch_idx % 100 == 0:
        loss, current = loss.item(), batch_idx * len(X)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        writer.add_scalar("train_loss", loss, total_train_step)
        train_correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_correct /= size
    writer.add_scalar("train_auc", train_correct, total_train_step)
    print(f"Train Error: \n Accuracy: {(100 * train_correct):>0.1f}% ")

def valid(dataloader, model, loss_fn):
    global total_val_step
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            X = batch[0].to(device)
            y = batch[1].to(device)

            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    val_loss /= num_batches
    correct /= size
    # writer.add_scalar("test_avg_loss", val_loss, total_val_step)
    writer.add_scalar("test_auc", correct, total_val_step)
    total_val_step += 1
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")

In [17]:
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader
from torcheeg.models import GRU

writer = SummaryWriter(r".\log\log_LSTM64_5k_shuffle111_batch8_epoch200_lr1e-4")
total_train_step = 0
total_val_step = 0

for i, (train_dataset, val_dataset) in enumerate(k_fold.split(dataset)):
    model = GRU(num_electrodes=16, hid_channels=64, num_classes=3).to(device)


    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    epochs = 200
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_loader, model, loss_fn, optimizer)
        valid(val_loader, model, loss_fn)
    print("Done!")
writer.close()

Epoch 1
-------------------------------
loss: 1.201394  [    0/  227]
loss: 1.180781  [    8/  227]
loss: 1.117204  [   16/  227]
loss: 1.097910  [   24/  227]
loss: 1.273638  [   32/  227]
loss: 1.106595  [   40/  227]
loss: 1.083001  [   48/  227]
loss: 1.054256  [   56/  227]
loss: 1.073216  [   64/  227]
loss: 1.099655  [   72/  227]
loss: 1.135724  [   80/  227]
loss: 1.177622  [   88/  227]
loss: 1.201264  [   96/  227]
loss: 1.188235  [  104/  227]
loss: 1.093947  [  112/  227]
loss: 1.059066  [  120/  227]
loss: 0.981992  [  128/  227]
loss: 1.135923  [  136/  227]
loss: 1.108213  [  144/  227]
loss: 1.101047  [  152/  227]
loss: 1.226500  [  160/  227]
loss: 1.171014  [  168/  227]
loss: 1.023875  [  176/  227]
loss: 1.171660  [  184/  227]
loss: 1.065082  [  192/  227]
loss: 1.019707  [  200/  227]
loss: 1.063549  [  208/  227]
loss: 1.003628  [  216/  227]
loss: 1.115245  [   84/  227]
Train Error: 
 Accuracy: 32.2% 
Test Error: 
 Accuracy: 33.3%, Avg loss: 1.131671 

Epoch 