# TEST

In [1]:
#%pip install openneuro-py mne mne_bids naplib autoreject PyWavelets --quiet

In [1]:
import os
import numpy as np
import mne
from mne.datasets import sample
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from openneuro import download
from autoreject import get_rejection_threshold

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Загрузка данных
# download(dataset='ds002778', target_dir='./ds002778')

In [2]:
# Расположение данных
dataset_path = '/Users/evakhromeeva/mne_data/ds002778'

In [3]:
# Параметры датасета
dataset = "ds002778"
subjects_pd = ["sub-pd3", "sub-pd5", "sub-pd6", "sub-pd6", "sub-pd9","sub-pd11", "sub-pd12", "sub-pd13", "sub-pd14", "sub-pd16", "sub-pd17", "sub-pd19", "sub-pd22", "sub-pd23", "sub-pd26", "sub-pd28"]# subject with PD
subjects_hc = ["sub-hc1", "sub-hc2", "sub-hc4", "sub-hc7", "sub-hc8", "sub-hc10", "sub-hc18", "sub-hc20", "sub-hc21", "sub-hc24", "sub-hc25", "sub-hc29", "sub-hc30", "sub-hc31", "sub-hc32", "sub-hc33"] # healthy subject

In [4]:
# путь до файла с ЭЭГ выбранного пациента (например, sub-pd6)
def get_eeg_data_path(sub):
    return f'{dataset_path}/{sub}/ses-off/eeg/{sub}_ses-off_task-rest_eeg.bdf' if sub.startswith('sub-pd') else f'{dataset_path}/{sub}/ses-hc/eeg/{sub}_ses-hc_task-rest_eeg.bdf'

In [5]:
# Загрузка и предобработка данных
def load_and_preprocess_data(data_list, is_healthy=True, duration=2.0):
    raw_files = [get_eeg_data_path(f) for f in data_list]
    epochs_list = []
    labels = []

    for raw_file in raw_files:
        raw = mne.io.read_raw_bdf(raw_file, preload=True)
        raw.drop_channels(['EXG1', 'EXG2', 'EXG3', 'EXG4','EXG5', 'EXG6', 'EXG7', 'EXG8', 'Status']) #drop extra channels
        raw.set_eeg_reference(ref_channels='average')
        raw.filter(0.5, None, fir_design='firwin',phase='zero-double') #remove drifts
        event_id = {'healthy': 0, 'parkinson': 1}
        epochs = mne.make_fixed_length_epochs(raw, duration=duration, preload=False, proj=True)
        reject = get_rejection_threshold(epochs)
        epochs.drop_bad(reject=reject)
        epochs_list.append(epochs)
        labels.extend([event_id['healthy'] if is_healthy else event_id['parkinson']] * len(epochs))

    epochs_data = np.concatenate([e.get_data() for e in epochs_list])
    labels = np.array(labels)
    return epochs_data, labels

In [6]:
# Преобразование данных для PyTorch
def prepare_torch_data(X, y):
    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.long)
    dataset = TensorDataset(X, y)
    return dataset

In [19]:
import torch
import torch.nn as nn
import torch.fft

class ImprovedEEGNetWithFFT(nn.Module):
    def __init__(self, n_classes=2):
        super(ImprovedEEGNetWithFFT, self).__init__()
        
        # Сверточные слои для временных данных
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(1, 64), padding=(0, 32))
        self.batchnorm1 = nn.BatchNorm2d(32)
        
        # Сверточные слои для частотных данных
        self.conv2_freq = nn.Conv2d(1, 32, kernel_size=(1, 32), padding=(0, 16))
        self.batchnorm2_freq = nn.BatchNorm2d(32)
        
        # Объединение признаков
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 32), padding=(0, 16))
        self.batchnorm3 = nn.BatchNorm2d(128)
        
        # Пулинг
        self.pooling = nn.AvgPool2d(kernel_size=(1, 8))
        
        # Dropout
        self.dropout = nn.Dropout(0.5)
        
        # Полносвязные слои
        self.fc1 = nn.Linear(128 * 1 * 16, 128)  # вот тут косяк! Здесь нужно подсунуть тензор torch.Size([32, 524288]) (x)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        # Временные признаки
        x_time = torch.relu(self.conv1(x))
        x_time = self.batchnorm1(x_time)
        
        # Частотные признаки (FFT)
        x_freq = torch.fft.fft(x, dim=-1).abs()  # Вычисляем амплитудный спектр

        x_freq = torch.relu(self.conv2_freq(x_freq))
        x_freq = self.batchnorm2_freq(x_freq)
        
        # Объединение временных и частотных признаков
        x = torch.cat((x_time, x_freq), dim=1)
        
        # Дальнейшая обработка
        x = torch.relu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.pooling(x)
        x = self.dropout(x)
        
        # Выравнивание данных для полносвязного слоя
        x = x.view(x.size(0), -1)
        
        # Полносвязные слои
        print(x.shape)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [20]:
# Инициализация модели, функции потерь и оптимизатора
model = ImprovedEEGNetWithFFT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [21]:
# Обучение модели
train_model(model, train_loader, criterion, optimizer, num_epochs=10)

torch.Size([32, 524288])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x524288 and 2048x128)

In [None]:
# Оценка модели
model.eval()
y_pred = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        y_pred.extend(predicted.numpy())

print(classification_report(y_test, y_pred))
print(f'Accuracy: {accuracy_score(y_test, y_pred)}')