In [1]:
import mne
mne.set_log_level('ERROR')  # выводить только ошибки
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
from braindecode.models import EEGNetv4 as EEGNet
import os
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
import torch.optim as optim
from tqdm import tqdm
import pandas as pd

### Preparation

In [2]:
def get_filtered_data(raw):
    
    raw_resampled = raw.copy().resample(256)
    
    # Убираем дрейф и ВЧ-шум → 0.5–45 Гц, мышечные артефакты (EMG), движение кожи
    raw_filtered = raw_resampled.copy().filter(l_freq=0.5, h_freq=45)
    
    # Удаление сетевой наводки (50 Гц)
    raw_filtered.notch_filter(freqs=[50])
    
    # Референцированиe
    # Пересчёт в среднее по всем каналам
    raw_filtered.set_eeg_reference('average')
    # дополнительно можно ICA -> find_bads_eog
    
    # Нормализация
    data = raw_filtered.get_data()  # [n_channels, n_times]
    # Z-score по каналам -убрать индивидуальные различия каналов по масштабу, сделать данные сопоставимыми
    data_zscored = (data - data.mean(axis=1, keepdims=True)) / data.std(axis=1, keepdims=True)
    return data_zscored

In [3]:
# for datasets with differences in channels
# globally used in 2 classes :(
needed_channels = [
     'C3-P3',
     'C4-P4',
     'CZ-PZ',
     'F3-C3',
     'F4-C4',
     'F7-T7',
     'F8-T8',
     'FP1-F3',
     'FP1-F7',
     'FP2-F4',
     'FP2-F8',
     'FT10-T8',
     'FT9-FT10',
     'FZ-CZ',
     'P3-O1',
     'P4-O2',
     'P7-O1',
     'P7-T7',
     'P8-O2',
     'T7-FT9',
     'T7-P7',
     'T8-P8-0']
def get_valid_files(dataset):

    res_set = []

    for f, label in dataset:
        try:
            raw = mne.io.read_raw_edf(f, preload=False, verbose=False)
            channels = raw.ch_names
            if all(ch in channels for ch in needed_channels):
                res_set.append((f, label))
        except Exception as e:
            print(f"{f.name}: {e}")
            
    return res_set

In [4]:
# Get list of files and labels
from os import listdir
from os.path import isfile, join
import mne
from pathlib import Path
training_set = [("patterns/normals/" + f, [1, 0]) for f in listdir("patterns/normals/") 
                if isfile(join("patterns/normals/", f)) and not f.startswith(('chb07', 'chb09', 'chb13', 'chb15'
                                                                          ,'chb06', 'chb18', 'chb23'))]
training_seiz = [("patterns/seizures/" + f, [0, 1]) for f in listdir("patterns/seizures/") 
                if isfile(join("patterns/seizures/", f)) and not f.startswith(('chb07', 'chb09', 'chb13', 'chb15'
                                                                          ,'chb06', 'chb18', 'chb23'))]
training_set.extend(training_seiz)
print(len(training_set))


val_set = [("patterns/normals/" + f, [1, 0]) for f in listdir("patterns/normals/") 
                if isfile(join("patterns/normals/", f)) and f.startswith(('chb07', 'chb09', 'chb13', 'chb15'))]
val_seiz = [("patterns/seizures/" + f, [0, 1]) for f in listdir("patterns/seizures/") 
                if isfile(join("patterns/seizures/", f)) and f.startswith(('chb07', 'chb09', 'chb13', 'chb15'))]
val_set.extend(val_seiz)
print(len(val_set))


test_set = [("patterns/normals/" + f, [1, 0]) for f in listdir("patterns/normals/") 
                if isfile(join("patterns/normals/", f)) and f.startswith(('chb06', 'chb18', 'chb23'))]
test_seiz = [("patterns/seizures/" + f, [0, 1]) for f in listdir("patterns/seizures/") 
                if isfile(join("patterns/seizures/", f)) and f.startswith(('chb06', 'chb18', 'chb23'))]
test_set.extend(test_seiz)
print(len(test_set))


res_training_set = get_valid_files(training_set)
res_val_set = get_valid_files(val_set)
res_test_set = get_valid_files(test_set)

print(len(res_training_set), len(res_val_set), len(res_test_set))

2161
454
258
2131 342 253


#### HEAD

In [5]:
# Dataset with labels
class EEGPatternDataset(Dataset):
    def __init__(self, file_label_pairs, window_size=512, cut_step=256):
        self.samples, self.labels = [], []
        for f, y in file_label_pairs:
            raw = mne.io.read_raw_edf(f, preload=True)
            # select only chanells above (for datasets with different channels)
            raw.pick(needed_channels)
            # filter raw record, return some numpy obj
            data = get_filtered_data(raw)
            # превращает обычный список в торч тензор
            # y = torch.tensor(label, dtype=torch.float32)
            
            for start in range(0, data.shape[1]-window_size, cut_step):
                segment = data[:, start:start+window_size]
                
                self.samples.append(segment.astype(np.float32))
                # гарантируем, что y — float32 tensor
                label_tensor = torch.tensor(y, dtype=torch.float32)
                self.labels.append(label_tensor)  # метка паттерна
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        x = torch.tensor(self.samples[idx], dtype=torch.float32)
        y = self.labels[idx]
        return x, y


In [11]:
def train_head(backbone, head, train_loader, val_loader, epochs=10, lr=1e-3, fine_tune=False):
    
    
    # saving histories
    train_losses = []
    val_losses = []
    val_accuracies = []
    precisions = []
    recalls = []
    F1s = []
    ROC_AUCs = []
    PR_AUCs = []
    
    # Заморозить backbone
    for p in backbone.parameters():
        p.requires_grad = False

    optimizer = optim.Adam(head.parameters(), lr=lr)
    # Для создания нескольких независимых друг от друга паттернов (мульти-лейбл) нужен .Sigmoid()
    # функция ниже как то применяет этот сигмоид у себя внутри
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        head.train()
        total_loss = 0

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            # y.float().unsqueeze(1).cuda()
            x, y = x.cuda(), y.float().cuda()
            with torch.no_grad():
                z = backbone(x)
            preds = head(z)
            loss = criterion(preds, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Train loss: {total_loss / len(train_loader):.4f}")
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Валидация
        
        # Предсказания и метки собираем по всей валидации
        all_preds = []
        all_labels = []
        head.eval()
        correct, total, val_loss = 0, 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.cuda(), y.cuda()
                z = backbone(x)
                preds = head(z)
                # метрики для анализа F1
                all_preds.append(preds.cpu())
                all_labels.append(y.cpu())
                loss = criterion(preds, y)
                val_loss += loss.item()
                # correct += ((preds > 0.5).float() == y).sum().item()
                correct += ((torch.sigmoid(preds) > 0.5).float() == y).sum().item()
                total += y.size(0)
                
        
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        val_acc = correct / total
        val_accuracies.append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} |"
              f" Train loss: {avg_train_loss:.4f} | Val loss: {avg_val_loss:.4f} | Val acc: {val_acc:.3f}")
        
        all_preds = torch.cat(all_preds).numpy()
        all_labels = torch.cat(all_labels).numpy()
        # Бинаризация (0.5 порог)
        binary_preds = (all_preds > 0.5).astype(int)

        precision = precision_score(all_labels, binary_preds, average='macro')
        precisions.append(precision)
        recall = recall_score(all_labels, binary_preds, average='macro')
        recalls.append(recall)
        f1 = f1_score(all_labels, binary_preds, average='macro')
        F1s.append(f1)
        auc_macro = roc_auc_score(all_labels, all_preds, average='macro')
        ROC_AUCs.append(auc_macro)
        pr_auc = average_precision_score(all_labels, all_preds, average='macro')
        PR_AUCs.append(pr_auc)

        print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}, ROC-AUC: {auc_macro:.3f}, PR-AUC: {pr_auc:.3f}")
        

        
    # ==== 4. Fine-tuning ====
    if fine_tune:
        print("\nРазмораживаем верхние слои backbone для тонкой настройки...")
        for p in list(backbone.parameters())[-6:]: # total = 12 (4-6 to unfreeze is ok)
            p.requires_grad = True

        optimizer = optim.Adam(list(backbone.parameters())[-6:] + list(head.parameters()), lr=1e-4)

        for epoch in range(3):  # несколько эпох fine-tuning
            head.train()
            for x, y in tqdm(train_loader, desc=f"Fine-tune epoch {epoch+1}"):
                # y.float().unsqueeze(1).cuda()
                x, y = x.cuda(), y.float().cuda()
                z = backbone(x)
                preds = head(z)
                loss = criterion(preds, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            print(f"Fine-tune epoch {epoch+1} loss: {total_loss / len(train_loader):.4f}")

        print("Fine-tuning завершён!")
        

    # Сохранение весов 
    torch.save(head.state_dict(), "head.pt")
    torch.save(backbone.state_dict(), "backbone_finetuned.pt")

    # Сохранение истории метрик 
    history = pd.DataFrame({
        "train_loss": train_losses,
        "val_loss": val_losses,
        "val_accuracy": val_accuracies,
        "precision": precisions,
        "recall": recalls,
        "F1": F1s,
        "ROC-AUC": ROC_AUCs,
        "PR-AUC": PR_AUCs
    })
    history.to_csv("training_history_head.csv", index=False)

In [22]:
# TODO: 
# 3) Load backbone for 4 epoch (as the best per val-loss) -> >:( something broke >:(

In [10]:
# debugging to show right trace to find an error -> синхронный запуск, не включать при норм работе!
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [8]:
backbone = EEGNet(
    n_chans=22, 
    n_outputs=128, 
    n_times=512,
    final_conv_length='auto'
)
backbone.classify = False

# Загружаем чекпойнт с CPU, чтобы избежать CUDA assert
checkpoint = torch.load("backbone_ssl_epoch04.pt", map_location="cpu")

# Извлекаем именно веса модели
state_dict = checkpoint['model_state']

# zагружаем в модель
backbone.load_state_dict(state_dict)

# переносим на GPU
backbone = backbone.cuda()



In [9]:
# Head
# Head class for classification/regression
class Head(nn.Module):
    def __init__(self, in_dim=128, out_dim=2):  # out_dim=2 -> 2 pattern first
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, out_dim)
        )
    def forward(self, x):
        return self.fc(x)

head = Head(in_dim=128, out_dim=2).cuda()

optimizer = torch.optim.Adam(list(backbone.parameters())+list(head.parameters()), lr=1e-4)


In [12]:
# res_training_set initialized earlier
# res_val_set initialized earlier
train_ds = EEGPatternDataset(res_training_set)
val_ds = EEGPatternDataset(res_val_set)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16)

# head initialized earlier
# backbone initialized earlier

train_head(backbone, head, train_loader, val_loader, epochs=10, fine_tune=False)

Epoch 1/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1052.55it/s]


Train loss: 0.4588
Epoch 1/10 | Train loss: 0.4588 | Val loss: 0.7863 | Val acc: 1.105
Precision: 0.621, Recall: 0.419, F1: 0.407, ROC-AUC: 0.668, PR-AUC: 0.647


Epoch 2/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1075.84it/s]


Train loss: 0.4567
Epoch 2/10 | Train loss: 0.4567 | Val loss: 0.8143 | Val acc: 1.100
Precision: 0.615, Recall: 0.418, F1: 0.390, ROC-AUC: 0.669, PR-AUC: 0.652


Epoch 3/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1074.54it/s]


Train loss: 0.4580
Epoch 3/10 | Train loss: 0.4580 | Val loss: 0.7680 | Val acc: 1.121
Precision: 0.653, Recall: 0.398, F1: 0.402, ROC-AUC: 0.669, PR-AUC: 0.656


Epoch 4/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1071.53it/s]


Train loss: 0.4540
Epoch 4/10 | Train loss: 0.4540 | Val loss: 0.8127 | Val acc: 1.107
Precision: 0.640, Recall: 0.424, F1: 0.389, ROC-AUC: 0.677, PR-AUC: 0.660


Epoch 5/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1051.09it/s]


Train loss: 0.4543
Epoch 5/10 | Train loss: 0.4543 | Val loss: 0.7522 | Val acc: 1.148
Precision: 0.641, Recall: 0.397, F1: 0.423, ROC-AUC: 0.664, PR-AUC: 0.650


Epoch 6/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1106.08it/s]


Train loss: 0.4536
Epoch 6/10 | Train loss: 0.4536 | Val loss: 0.8281 | Val acc: 1.105
Precision: 0.621, Recall: 0.435, F1: 0.410, ROC-AUC: 0.675, PR-AUC: 0.657


Epoch 7/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1079.63it/s]


Train loss: 0.4545
Epoch 7/10 | Train loss: 0.4545 | Val loss: 0.7897 | Val acc: 1.136
Precision: 0.630, Recall: 0.420, F1: 0.415, ROC-AUC: 0.675, PR-AUC: 0.656


Epoch 8/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1051.36it/s]


Train loss: 0.4524
Epoch 8/10 | Train loss: 0.4524 | Val loss: 0.8359 | Val acc: 1.099
Precision: 0.652, Recall: 0.438, F1: 0.376, ROC-AUC: 0.693, PR-AUC: 0.679


Epoch 9/10: 100%|█████████████████████████| 1567/1567 [00:01<00:00, 1100.53it/s]


Train loss: 0.4547
Epoch 9/10 | Train loss: 0.4547 | Val loss: 0.7632 | Val acc: 1.142
Precision: 0.629, Recall: 0.414, F1: 0.413, ROC-AUC: 0.678, PR-AUC: 0.659


Epoch 10/10: 100%|████████████████████████| 1567/1567 [00:01<00:00, 1109.27it/s]


Train loss: 0.4531
Epoch 10/10 | Train loss: 0.4531 | Val loss: 0.7549 | Val acc: 1.149
Precision: 0.640, Recall: 0.414, F1: 0.426, ROC-AUC: 0.674, PR-AUC: 0.656
