In [13]:
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim

from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix

In [14]:
# 获取当前日期
current_date = datetime.now().strftime("%Y-%m-%d")

outdir = 'outdir'
if not os.path.exists(outdir):
    os.makedirs(outdir)

# 构建导出的pth文件名以及ONNX文件名
pth_file_path = os.path.join(outdir, f'eeg_depression-best_model_{current_date}.pth')
onnx_file_path = os.path.join(outdir, f"eeg_depression-best_model_{current_date}.onnx")

In [15]:
class CustomDataset(Dataset):
    def __init__(self, positive_file, negative_file):
        # Load positive samples
        with open(positive_file, 'rb') as f:
            self.positive_samples = pickle.load(f)
        
        # Load negative samples
        with open(negative_file, 'rb') as f:
            self.negative_samples = pickle.load(f)

        # Ensure both sets have the same length
        self.length = min(len(self.positive_samples), len(self.negative_samples))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Concatenate positive and negative samples to form a batch
        subject, positive_feature, positive_label = self.positive_samples[idx % len(self.positive_samples)]
        subject, negative_feature, negative_label = self.negative_samples[idx % len(self.negative_samples)]
        return torch.tensor(positive_feature), torch.tensor(negative_feature)

class SimpleDNN(nn.Module):
    def __init__(self):
        super(SimpleDNN, self).__init__()
        self.fc1 = nn.Linear(128, 64)  # Input size 100, output size 64
        self.dropout1 = nn.Dropout(p=0.3)  # Dropout with a probability of 0.5
        self.fc2 = nn.Linear(64, 32)   # Input size 64, output size 32
        self.dropout2 = nn.Dropout(p=0.3)  # Dropout with a probability of 0.5
        self.fc3 = nn.Linear(32, 1)    # Input size 32, output size 1

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        x = torch.sigmoid(self.fc3(x))
        return x

# Define paths to your data files
train_positive_file = './samples/train_positive_samples.pkl'
train_negative_file = './samples/train_negative_samples.pkl'
test_positive_file = './samples/test_positive_samples.pkl'
test_negative_file = './samples/test_negative_samples.pkl'

# Create datasets
train_dataset = CustomDataset(train_positive_file, train_negative_file)
test_dataset = CustomDataset(test_positive_file, test_negative_file)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize model, loss function, and optimizer
model = SimpleDNN()
criterion = nn.BCELoss()  # Binary Cross Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam optimizer

best_accuracy = 0.0

# Training loop
for epoch in range(30):  # Train for 10 epochs
    for batch_idx, (positive_batch, negative_batch) in enumerate(train_loader):
        # Concatenate positive and negative batches
        inputs = torch.cat((positive_batch, negative_batch), dim=0)
        labels = torch.cat((torch.ones(positive_batch.size(0), 1), torch.zeros(negative_batch.size(0), 1)), dim=0)

        optimizer.zero_grad()  # Zero the gradients
        outputs = model(inputs.float())  # Forward pass
        loss = criterion(outputs, labels.float())  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

    # Testing loop
    model.eval()  # Set the model to evaluation mode
    test_correct = 0
    total_samples = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for positive_batch, negative_batch in test_loader:
            inputs = torch.cat((positive_batch, negative_batch), dim=0)
            labels = torch.cat((torch.ones(positive_batch.size(0), 1), torch.zeros(negative_batch.size(0), 1)), dim=0)
            
            outputs = model(inputs.float())
            predicted = (outputs > 0.5).float()  # Threshold at 0.5
            total_samples += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            all_labels.extend(labels.numpy())
            all_predictions.extend(predicted.numpy())

    accuracy = test_correct / total_samples
    print(f'Epoch: {epoch} Train Loss: {loss} Test Accuracy: {accuracy}')

    # Save the best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), pth_file_path)

        # Export the best model to ONNX
        dummy_input = torch.randn(1, 128)
        torch.onnx.export(model, dummy_input, onnx_file_path, input_names=['input'], output_names=['output'])

print(f"best_accuracy: {best_accuracy}")

# Load the best model
model.load_state_dict(torch.load(pth_file_path))

# Calculate confusion matrix
conf_matrix = confusion_matrix(all_labels, all_predictions)
print("Confusion Matrix:")
print(conf_matrix)




Epoch: 0 Train Loss: 0.5409842133522034 Test Accuracy: 0.6988108812510181
Epoch: 1 Train Loss: 0.49743080139160156 Test Accuracy: 0.7063854047890535
Epoch: 2 Train Loss: 0.3357866108417511 Test Accuracy: 0.7057338328718032
Epoch: 3 Train Loss: 0.3512541949748993 Test Accuracy: 0.7151001791822772
Epoch: 4 Train Loss: 0.5077861547470093 Test Accuracy: 0.7392897866101971
Epoch: 5 Train Loss: 0.3912697732448578 Test Accuracy: 0.7300048867893794
Epoch: 6 Train Loss: 0.46829667687416077 Test Accuracy: 0.7259325623065646
Epoch: 7 Train Loss: 0.3620418310165405 Test Accuracy: 0.7341586577618505
Epoch: 8 Train Loss: 0.5151393413543701 Test Accuracy: 0.7269099201824402
Epoch: 9 Train Loss: 0.4151233732700348 Test Accuracy: 0.7245479719824076
Epoch: 10 Train Loss: 0.40602409839630127 Test Accuracy: 0.7346473366997882
Epoch: 11 Train Loss: 0.49229711294174194 Test Accuracy: 0.7300863332790357
Epoch: 12 Train Loss: 0.4111693501472473 Test Accuracy: 0.7269099201824402
Epoch: 13 Train Loss: 0.4453388

In [16]:
# output the predictions in each test subjects

In [17]:
import glob
import numpy as np
import onnxruntime as ort

from collections import defaultdict

In [18]:
# 加载 ONNX 模型
# onnx_model_path = './outdir/eeg_depression-best_model_2024-06-27.onnx'
# ort_session = ort.InferenceSession(onnx_model_path)

def get_onnxfile():
    outdir = './outdir'
    onnx_files = glob.glob(os.path.join(outdir, '*.onnx'))
    # 如果找到的文件不为空
    if onnx_files:
        # 按文件名排序
        onnx_files.sort()

        # 选择最后一个文件
        last_onnx_file = onnx_files[-1]
        return last_onnx_file

onnx_model_path = get_onnxfile()
ort_session = ort.InferenceSession(onnx_model_path)

# 预测函数
def predict(features):
    inputs = {ort_session.get_inputs()[0].name: features}
    outputs = ort_session.run(None, inputs)
    return outputs[0]


for filePath in [test_negative_file, test_positive_file]:
    if filePath == test_negative_file:
        print("\nTest for Negative subjects:")
    else:
        print("\nTest for Positive subjects:")

    # 加载测试数据集
    with open(filePath, 'rb') as f:
        test_data = pickle.load(f)

    # 计算每个 subject 的准确率
    subject_results = defaultdict(list)

    for subject, features, label in test_data:
        features = np.array(features, dtype=np.float32)  # 确保 features 的数据类型与模型输入匹配
        features = features.reshape(1, -1)  # 调整形状以匹配模型输入
        prediction = predict(features)
        predicted_label = (prediction >= 0.5).astype(int)  # 使用阈值 0.5 将预测概率转换为类标签
        subject_results[subject].append((predicted_label, label))

    subject_accuracies = {}
    for subject, results in subject_results.items():
        correct = sum(1 for pred, label in results if pred == label)
        accuracy = correct / len(results)
        subject_accuracies[subject] = accuracy

    # 输出每个 subject 的准确度
    for subject, accuracy in subject_accuracies.items():
        print(f'Subject: {subject}, Accuracy: {accuracy:.2f}')


Test for Negative subjects:
Subject: A08, Accuracy: 0.58
Subject: A02, Accuracy: 0.73
Subject: A22, Accuracy: 0.98
Subject: A06, Accuracy: 0.97
Subject: A10, Accuracy: 0.63
Subject: A24, Accuracy: 0.47
Subject: A30, Accuracy: 0.92
Subject: A17, Accuracy: 0.96
Subject: A28, Accuracy: 0.96

Test for Positive subjects:
Subject: 20240619_2A50, Accuracy: 0.95
Subject: 20240710_2A28, Accuracy: 0.38
Subject: 20240717_2A12, Accuracy: 0.61
Subject: 20240626_2A26, Accuracy: 0.90
Subject: 20240619_2A15, Accuracy: 0.85
Subject: 20240710_2A30, Accuracy: 0.76
Subject: 20240626_2A13, Accuracy: 0.19
Subject: 20240710_2A05, Accuracy: 0.97
Subject: 20240710_2A31, Accuracy: 0.77
