In [4]:
import glob
from collections import defaultdict
from pathlib import Path

import mne
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm

## Подготовка метадаты

In [5]:
eeg_files = glob.glob('/mnt/d/Study/PhD/Data/EEG/*/*.set')
df_eegs = pd.DataFrame(eeg_files, columns=['path'])

In [6]:
df_eegs['session_type'] = df_eegs['path'].apply(lambda x: x.split('/')[-2])
df_eegs['mono'] = df_eegs['path'].apply(lambda x: 'mono' in x.split('/')[-1])
df_eegs['subject_id'] = df_eegs['path'].apply(lambda x: '_'.join(x.split('/')[-1].replace('_mono', '').split('_')[:3]))

In [7]:
df_eegs.session_type.value_counts()

session_type
fon      32
other    32
own      32
Name: count, dtype: int64

## Подготовка кропов
Подготовка датасета -- кропы с меткой открытых/закрытых глаз по аннотациям.

Кропы фиксированной длины с пересечением, с отступом до и после аннотации сегмента.

In [9]:
def extract_ann_info(path, target_sfreq=500):
    """Load a .set file, return dict with path, total duration (s), ann.onset, ann.description."""
    try:
        raw = mne.io.read_raw_eeglab(path, preload=True, verbose=False)
        if raw.info["sfreq"] != target_sfreq:
            raw.resample(target_sfreq, verbose=False)
        n_samples = raw.n_times
        sfreq = raw.info["sfreq"]
        duration_sec = n_samples / sfreq
        ann = raw.annotations

        # cast onset and description to Python list for DataFrame compatibility
        onset_list = ann.onset.tolist()
        description_list = list(ann.description)
        return {
            'path': path,
            'duration_sec': duration_sec,
            'ann_onset': onset_list,
            'ann_description': description_list
        }
    except Exception as e:
        return {
            'path': path,
            'duration_sec': None,
            'ann_onset': None,
            'ann_description': None,
            'error': str(e)
        }

In [10]:
# получение метаинформации по сегментам (2.5 min / 96 eegs)
records = Parallel(n_jobs=-1, verbose=1)(
    delayed(extract_ann_info)(path, target_sfreq)
    for path in tqdm(df_eegs['path'], desc="Reading EEGs with ann info")
)

df_ann_info = pd.DataFrame(records)

Reading EEGs with ann info:   0%|          | 0/96 [00:00<?, ?it/s][Parallel(n_jobs=-1)]: Using backend LokyBackend with 32 concurrent workers.
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
Reading EEGs with ann info: 100%|██████████| 96/96 [01:12<00:00,  1.32it/s]
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
[Parallel(n_jobs=-1)]: Done  96 out of  96 | elapsed:  2.5min finished


In [11]:
df_ann_info['ann_description'] = df_ann_info['ann_description'].apply(lambda x: list(map(str, x)))

In [19]:
df_eegs = df_eegs.merge(df_ann_info, how='left', on='path')

In [20]:
df_eegs.to_parquet('../data/classification/raw_eegs.parquet', index=False)

In [21]:
df_eegs = pd.read_parquet('../data/classification/raw_eegs.parquet')

## Фильтрация сегментов

In [56]:
# Convert ann_description to list of str and rename 'S 1'/'S 2' to 'S1'/'S2'
def clean_ann_description(description):
    # Handle None or NaN
    if description is None:
        return []
    # If it's already a list, work with it
    # If not, try to convert from numpy array or string repr
    if isinstance(description, str):
        # Try to eval as a list-like string
        import ast
        try:
            description = ast.literal_eval(description)
        except Exception:
            description = [description]
    # Convert to list for non-list types (e.g., numpy array)
    desc_list = list(description)

    # Fix string formatting
    out = []
    for d in desc_list:
        d_str = str(d).strip()
        if d_str in ['S 1', 'S   1']:
            out.append('S1')
        elif d_str in ['S 2', 'S   2']:
            out.append('S2')
        else:
            out.append(d_str)
    return out

df_eegs['ann_description'] = df_eegs['ann_description'].apply(clean_ann_description)

In [57]:
df_eegs['all_segments'] = df_eegs.apply(lambda row: [float(seg_start) for seg_start in row['ann_onset']] + [float(row['duration_sec'])], axis=1)
df_eegs['segment_durations'] = df_eegs['all_segments'].apply(lambda x: [round(float(x[i+1]) - float(x[i]), 2) for i in range(len(x)-1)])
def get_segments_by_type(row, seg_type):
    segs = []
    ann_desc = row['ann_description']
    seg_starts = row['all_segments']
    for i, desc in enumerate(ann_desc):
        if desc == seg_type:
            segs.append((seg_starts[i], seg_starts[i+1]))
    return segs

df_eegs['segments_S1'] = df_eegs.apply(lambda row: get_segments_by_type(row, 'S1'), axis=1)
df_eegs['segments_S2'] = df_eegs.apply(lambda row: get_segments_by_type(row, 'S2'), axis=1)

# all durations > 50 and < 130
df_eegs['good_eeg'] = df_eegs['segment_durations'].apply(lambda x: all(50 < d < 150 for d in x))
df_eegs['bad_segments'] = df_eegs['segment_durations'].apply(lambda x: [d for d in x if d < 50 or d > 130])

In [59]:
df_eegs.to_parquet('../data/classification/raw_eegs.parquet', index=False)

In [None]:
## Сегменты
# Make frame where each row is a S1 or S2 segment (with label and EEG metadata)

segment_rows = []

for idx, row in df_eegs.iterrows():
    # S1 segments
    for seg in row.get('segments_S1', []):
        segment_rows.append({
            'path': row['path'],
            'session_type': row['session_type'],
            'mono': row['mono'],
            'subject_id': row.get('subject_id', None),
            'duration_sec': row.get('duration_sec', None),
            'good_eeg': row.get('good_eeg', None),
            'segment_type': 'S1',
            'segment_start': seg[0],
            'segment_end': seg[1],
            'segment_len': seg[1] - seg[0],
            'bad_segments_for_eeg': row.get('bad_segments', []),
        })
    # S2 segments
    for seg in row.get('segments_S2', []):
        segment_rows.append({
            'path': row['path'],
            'session_type': row['session_type'],
            'mono': row['mono'],
            'subject_id': row.get('subject_id', None),
            'duration_sec': row.get('duration_sec', None),
            'good_eeg': row.get('good_eeg', None),
            'segment_type': 'S2',
            'segment_start': seg[0],
            'segment_end': seg[1],
            'segment_len': seg[1] - seg[0],
            'bad_segments_for_eeg': row.get('bad_segments', []),
        })

df_segments = pd.DataFrame(segment_rows)


In [63]:
df_segments.to_parquet('../data/classification/raw_eegs_segments.parquet')

## Нарезка кропов

In [65]:
CROPS_DIR = '../data/classification_crops'

target_sfreq = 500  # Hz

# sliding window params
segment_sec = 5.0           # window length in seconds (crop length)
overlap_sec = 2.0           # overlap in seconds
segment_margin_sec = 5.0    # clean margin from event start/end in seconds

In [67]:
from tqdm.auto import tqdm

import mne
import numpy as np
import os

def crop_segments_for_eeg(row, crops_dir=CROPS_DIR):
    """For each segment (S1/S2) in row, resample, crop into overlapping crops, save and return rows."""
    crop_rows = []
    try:
        raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
        if raw.info['sfreq'] != target_sfreq:
            raw.resample(target_sfreq)
        # for both S1 and S2 segments:
        for seg_type in ['segments_S1', 'segments_S2']:
            segments = row.get(seg_type, [])
            stype = 'S1' if seg_type == 'segments_S1' else 'S2'
            for seg_start, seg_end in segments:
                # apply margins
                crop_start = seg_start + segment_margin_sec
                crop_end = seg_end - segment_margin_sec
                if crop_end <= crop_start:
                    continue  # skip too small segments
                t = crop_start
                while t + segment_sec <= crop_end + 1e-5:  # allow fp error
                    start_sample = int(t * target_sfreq)
                    stop_sample = int((t + segment_sec) * target_sfreq)
                    crop = raw.get_data(start=start_sample, stop=stop_sample)
                    crop_id = f"{os.path.splitext(os.path.basename(row['path']))[0]}_{stype}_{int(1000*t)}_{int(1000*(t+segment_sec))}.npy"
                    crop_path = os.path.join(crops_dir, crop_id)
                    np.save(crop_path, crop.astype(np.float32))
                    crop_rows.append({
                        'crop_path': crop_path,
                        'subject_id': row.get('subject_id', None),
                        'session_type': row.get('session_type', None),
                        'mono': row.get('mono', None),
                        'segment_type': stype,
                        'segment_orig_start': seg_start,
                        'segment_orig_end': seg_end,
                        'crop_start_time': t,
                        'crop_end_time': t + segment_sec,
                        'good_eeg': row.get('good_eeg', None),
                    })
                    t += segment_sec - overlap_sec
    except Exception as e:
        print(f"Failed {row['path']}: {e}")
    return crop_rows

os.makedirs(CROPS_DIR, exist_ok=True)

all_crop_rows = []
rows = df_eegs.to_dict('records')

# Run serially to avoid pickling error with joblib/Parallel
with tqdm(total=len(rows), desc="Cropping EEGs (serial)") as pbar:
    for row in rows:
        result = crop_segments_for_eeg(row)
        all_crop_rows.extend(result)
        pbar.update(1)

df_crops = pd.DataFrame(all_crop_rows)
df_crops.to_parquet('../data/classification/raw_eegs_crops.parquet')

Cropping EEGs (serial):   0%|          | 0/96 [00:00<?, ?it/s]

  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  warn(
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'], preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(row['path'],

In [3]:
import torch
import numpy as np
import pandas as pd

# Load the dataframe with crop segment paths
df = pd.read_parquet('../data/classification/raw_eegs_crops.parquet')
print("Number of segments:", len(df))
print("First 3 crop paths:\n", df['crop_path'].head(3).tolist())

# Load a batch of EEG crops (e.g., first 8)
batch_size = 128
batch_paths = df['crop_path'].head(batch_size).tolist()

eeg_segments = []
for p in batch_paths:
    d = np.load(p)  # shape: (n_channels, n_samples)
    eeg_segments.append(d)
eeg_segments = np.stack(eeg_segments)  # (batch_size, n_channels, n_samples)
print("EEG segments np array shape:", eeg_segments.shape)

# Convert to torch tensor
inputs = torch.from_numpy(eeg_segments.astype(np.float32))  # (batch, ch, sampl)
print("Input torch tensor shape:", inputs.shape)

# Optionally move to cuda if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = inputs.to(device)

# Import/load the model (see EEGNet from eda_eegs.ipynb)
# For demonstration, assume n_channels and n_samples:
n_channels = inputs.shape[1]
in_samples = inputs.shape[2]
n_classes = 2  # or as appropriate

# Define model (copy-paste EEGNet or import if available)
class EEGNet(torch.nn.Module):
    def __init__(self, n_channels=128, in_samples=2500, n_classes=2):
        super(EEGNet, self).__init__()
        self.conv1 = torch.nn.Conv1d(n_channels, 64, kernel_size=7, stride=1, padding=3)
        self.bn1 = torch.nn.BatchNorm1d(64)
        self.conv2 = torch.nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2, groups=64)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.conv3 = torch.nn.Conv1d(128, 256, kernel_size=1)
        self.bn3 = torch.nn.BatchNorm1d(256)
        self.pool = torch.nn.AdaptiveAvgPool1d(32)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(256 * 32, 128)
        self.fc2 = torch.nn.Linear(128, n_classes)
    def forward(self, x):
        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        x = torch.nn.functional.relu(self.bn2(self.conv2(x)))
        x = torch.nn.functional.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = EEGNet(n_channels=n_channels, in_samples=in_samples, n_classes=n_classes).to(device)

with torch.no_grad():
    out = model(inputs)
print("Model output shape:", out.shape)



Number of segments: 6841
First 3 crop paths:
 ['../data/classification_crops/Co_y6_003_fon1_S1_41292_46292.npy', '../data/classification_crops/Co_y6_003_fon1_S1_44292_49292.npy', '../data/classification_crops/Co_y6_003_fon1_S1_47292_52292.npy']
EEG segments np array shape: (128, 128, 2500)
Input torch tensor shape: torch.Size([128, 128, 2500])
Model output shape: torch.Size([128, 2])


In [4]:
df_crops["crop_len"] = df_crops["crop_end_time"] - df_crops["crop_start_time"]

NameError: name 'df_crops' is not defined

In [104]:
import os
import numpy as np
import collections

# Папка с .npy-кропами
crops_dir = "../data/classification_crops/"  # замените при необходимости

DESIRED_CROP_LEN = 2500

def load_and_fix_crop(path, desired_len=DESIRED_CROP_LEN, save_if_changed=True):
    """
    Загружает .npy, корректирует длину до desired_len (обрезка или паддинг), сохраняет если изменено.
    Возвращает итоговую форму.
    """
    arr = np.load(path)
    orig_shape = arr.shape
    n_chans, n_samples = arr.shape

    if n_samples == desired_len:
        return orig_shape  # ок, ничего не делаем

    # Исправляем
    if n_samples > desired_len:
        arr_fixed = arr[:, :desired_len]
    else:  # n_samples < desired_len
        pad_width = desired_len - n_samples
        arr_fixed = np.pad(arr, ((0, 0), (0, pad_width)), mode='constant')

    # Сохраняем обратно, если нужно
    if save_if_changed:
        np.save(path, arr_fixed)
    return arr_fixed.shape

all_shapes = []
crop_paths = []

# Сбор кропов и проверка формы
for root, _, files in os.walk(crops_dir):
    for f in files:
        if f.endswith('.npy'):
            path = os.path.join(root, f)
            try:
                shape = load_and_fix_crop(path, DESIRED_CROP_LEN, save_if_changed=True)
                all_shapes.append(shape)
                crop_paths.append(path)
            except Exception as e:
                print(f"Ошибка при обработке {path}: {e}")

# Статистика форм после коррекции
shape_counts = collections.Counter(all_shapes)

print("Статистика по шейпам .npy файлов (после коррекции):")
for shape, count in shape_counts.items():
    print(f"Shape {shape}: {count} файлов")

if all_shapes:
    arr_shapes = np.array(all_shapes)
    for axis in range(arr_shapes.shape[1]):
        min_dim = arr_shapes[:, axis].min()
        max_dim = arr_shapes[:, axis].max()
        print(f"Ось {axis}: min={min_dim}, max={max_dim}")
else:
    print("Файлы .npy не найдены в указанной директории.")

Статистика по шейпам .npy файлов (после коррекции):
Shape (128, 2500): 6841 файлов
Ось 0: min=128, max=128
Ось 1: min=2500, max=2500


In [5]:
df_crops = pd.read_parquet('../data/classification/raw_eegs_crops.parquet')

In [6]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score

# --- Блок проверки датасета S1/S2 и базовых свойств ---

# Проверим уникальные значения segment_type
print("Уникальные значения df_crops['segment_type']:", df_crops["segment_type"].unique())

# Проверим распределение классов
print("Распределение S1/S2:")
print(df_crops["segment_type"].value_counts())

# label = 0 для S1, label = 1 для S2
label_map = {"S1": 0, "S2": 1}
df_crops = df_crops.copy()
df_crops["label"] = df_crops["segment_type"].map(label_map)

# Убедимся, что все метки корректно проставлены
print("Уникальные метки label: ", df_crops["label"].unique())
print("Распределение по label:")
print(df_crops["label"].value_counts())

# Проверим, что соответствие между segment_type и label без пропусков
assert df_crops["label"].isna().sum() == 0, "В label есть пропуски!"

# Проверим, что в каждой записи crop_path существует
missing_crop_paths = df_crops["crop_path"].apply(lambda p: not os.path.exists(p))
n_missing = missing_crop_paths.sum()
if n_missing > 0:
    print(f"Внимание: {n_missing} файлов по crop_path не найдены!")
    print(df_crops[missing_crop_paths][["crop_path", "subject_id", "segment_type"]])
else:
    print("Все crop_path существуют.")

# Проверим сколько уникальных subject_id:
subj_counts = df_crops["subject_id"].value_counts()
print(f"Уникальных subject_id: {df_crops['subject_id'].nunique()}")
print("Топ-5 субъектов по числу crop-ов:")
print(subj_counts.head())

# --- Разделяем на train/val по subject_id ---
unique_subjects = df_crops["subject_id"].unique()
np.random.seed(42)
np.random.shuffle(unique_subjects)
train_subjects = unique_subjects[:23]
val_subjects = unique_subjects[23:]

print(f"Train subjects:\n{sorted(train_subjects)} (n={len(train_subjects)})")
print(f"Val subjects:\n{sorted(val_subjects)} (n={len(val_subjects)})")

train_mask = df_crops["subject_id"].isin(train_subjects)
val_mask = df_crops["subject_id"].isin(val_subjects)
df_train = df_crops[train_mask].reset_index(drop=True)
df_val = df_crops[val_mask].reset_index(drop=True)

print(f"Train crops: {len(df_train)}; Val crops: {len(df_val)}")

# Проверим баланс классов в train и val
print("Train распределение label:")
print(df_train["label"].value_counts())
print("Val распределение label:")
print(df_val["label"].value_counts())

# Дополнительно: не пересекаются ли subject_id в train/val?
intersection = set(df_train["subject_id"]).intersection(set(df_val["subject_id"]))
print(f"Пересечение subject_id между train и val: {intersection}")

# --- Проверяем форму и данные в .npy ---
def load_eeg_npy(path):
    arr = np.load(path)
    return arr

# Сформируем shape-чеки
X_train_shapes = []
for path in df_train["crop_path"]:
    arr = load_eeg_npy(path)
    X_train_shapes.append(arr.shape)
X_val_shapes = []
for path in df_val["crop_path"]:
    arr = load_eeg_npy(path)
    X_val_shapes.append(arr.shape)

print("Train shapes (уникальные):", set(X_train_shapes))
print("Val shapes (уникальные):", set(X_val_shapes))

# Берём только если все сэмплы одинаковой формы:
if len(set(X_train_shapes)) == 1 and len(set(X_val_shapes)) == 1:
    X_train = np.stack([load_eeg_npy(path) for path in df_train["crop_path"]])
    y_train = df_train["label"].values
    X_val = np.stack([load_eeg_npy(path) for path in df_val["crop_path"]])
    y_val = df_val["label"].values

    n_channels = X_train.shape[1]
    in_samples = X_train.shape[2]
    n_classes = 2
    print(f"Финальный X_train: {X_train.shape}, y_train: {y_train.shape}")
    print(f"Финальный X_val: {X_val.shape}, y_val: {y_val.shape}")
else:
    print("В train или val есть .npy разной формы! Проверьте исходные данные.")

# Дополнительно: sanity check — совмещаем label и содержимое
print("Пример: перваые 5 меток train/val: ", y_train[:5], y_val[:5])
print("Пример формы первого X_train:", X_train[0].shape)



Уникальные значения df_crops['segment_type']: ['S1' 'S2']
Распределение S1/S2:
segment_type
S1    3810
S2    3031
Name: count, dtype: int64
Уникальные метки label:  [0 1]
Распределение по label:
label
0    3810
1    3031
Name: count, dtype: int64


NameError: name 'os' is not defined

In [7]:
# Лосс растёт сразу с первой эпохи, а качество держится около 0.6. Давайте попробуем изменить тренировочный процесс:
# 1. Попробуем уменьшить learning rate и batch size, чтобы оптимизация была стабильнее.
# 2. Добавим нормализацию перед моделью (по каналам) вне DataLoader, на весь датасет.
# 3. Упростим архитектуру: избавимся от агрессивного dropout, попробуем убрать логарифмирование/степени.
# 4. Проведём sanity check: поменяем инициализацию последнего слоя на меньшие значения.
# 5. Введём фиксированное seed для torch/numpy.
# 6. Еще: scheduler зафризим, чтобы не мешал коротким датасетам.

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Seed fix
import random
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# 1. Глобальная пер-семпл нормализация X_train/X_val по каналам
def channelwise_zscore(X):
    # X: (N, ch, t)
    mean = X.mean(axis=(0, 2), keepdims=True)
    std = X.std(axis=(0, 2), keepdims=True) + 1e-6
    return (X - mean) / std

X_train_norm = channelwise_zscore(X_train)
X_val_norm = channelwise_zscore(X_val)

class EEGDataset(Dataset):
    """EEG dataset без лишней batch-статистики; z-score каналов уже был применён."""
    def __init__(self, X, y):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]), torch.tensor(self.y[idx])

# 2. Модифицируем сеть: меньше dropout, проще pipeline
class SimplerShallowConvNet(torch.nn.Module):
    def __init__(self, n_channels=128, in_samples=2500, n_classes=2):
        super().__init__()
        self.conv_time = torch.nn.Conv1d(n_channels, 32, kernel_size=25, stride=1, padding=12, bias=False)
        self.bn = torch.nn.BatchNorm1d(32)
        self.pool = torch.nn.AvgPool1d(75, stride=15, padding=37)
        self.dropout = torch.nn.Dropout(0.15)
        # Вычислям выходной размер
        pool_len = (in_samples + 2*37 - 75) // 15 + 1
        self.fc = torch.nn.Linear(32 * pool_len, n_classes)
        # Инициализация головы для лучших стартовых значений
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=0.1)

    def forward(self, x):
        x = self.conv_time(x)
        x = self.bn(x)
        # Убираем степени и логарифм - это часто мешает градиентам на старте
        x = torch.relu(x)
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

batch_size = 64  # был 512 → уменьшаем для повторяемости/стабильности лосса
train_ds = EEGDataset(X_train_norm, y_train)
val_ds = EEGDataset(X_val_norm, y_val)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

model = SimplerShallowConvNet(n_channels=n_channels, in_samples=in_samples, n_classes=n_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)  # был 1e-3
# Отключаем scheduler (или делаем его очень терпеливым для диагностики)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1000)
criterion = torch.nn.CrossEntropyLoss()

print(f"X_train_norm shape: {X_train_norm.shape}, X_val_norm shape: {X_val_norm.shape}")
print(f"Train class counts: {np.bincount(y_train)}, Val class counts: {np.bincount(y_val)}")
print("Model summary:")
print(model)

num_epochs = 30
patience = 30
best_bal_acc = 0
epochs_no_improve = 0
best_state_dict = None

train_loss_hist, val_loss_hist = [], []
train_acc_hist, val_acc_hist = [], []
train_bal_acc_hist, val_bal_acc_hist = [], []
train_f1_hist, val_f1_hist = [], []

for epoch in range(1, num_epochs + 1):
    model.train()
    train_losses, all_train_preds, all_train_true = [], [], []

    train_bar = tqdm(enumerate(train_dl), total=len(train_dl), desc=f"Epoch {epoch} [Train]", leave=False)
    for batch_i, (xb, yb) in train_bar:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        preds = out.argmax(dim=1).detach().cpu().numpy()
        all_train_preds.append(preds)
        all_train_true.append(yb.detach().cpu().numpy())
        train_bar.set_postfix({"Batch loss": f"{loss.item():.4f}"})

    train_y_pred = np.concatenate(all_train_preds)
    train_y_true = np.concatenate(all_train_true)
    train_acc = accuracy_score(train_y_true, train_y_pred)
    train_bal_acc = balanced_accuracy_score(train_y_true, train_y_pred)
    train_f1 = f1_score(train_y_true, train_y_pred)
    train_loss_epoch = np.mean(train_losses)
    train_loss_hist.append(train_loss_epoch)
    train_acc_hist.append(train_acc)
    train_bal_acc_hist.append(train_bal_acc)
    train_f1_hist.append(train_f1)

    # Validation
    model.eval()
    all_val_preds, all_val_true = [], []
    val_losses = []
    with torch.no_grad():
        val_bar = tqdm(val_dl, total=len(val_dl), desc=f"Epoch {epoch} [Val]", leave=False)
        for xb, yb in val_bar:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            val_loss = criterion(logits, yb)
            val_losses.append(val_loss.item())
            preds = logits.argmax(dim=1).cpu().numpy()
            all_val_preds.append(preds)
            all_val_true.append(yb.cpu().numpy())
            val_bar.set_postfix({"Batch loss": f"{val_loss.item():.4f}"})

    if len(all_val_preds) > 0:
        val_y_pred = np.concatenate(all_val_preds)
        val_y_true = np.concatenate(all_val_true)
        bal_acc = balanced_accuracy_score(val_y_true, val_y_pred)
        acc = accuracy_score(val_y_true, val_y_pred)
        f1 = f1_score(val_y_true, val_y_pred)
        val_loss_epoch = np.mean(val_losses)
    else:
        val_loss_epoch = bal_acc = acc = f1 = float('nan')

    val_loss_hist.append(val_loss_epoch)
    val_acc_hist.append(acc)
    val_bal_acc_hist.append(bal_acc)
    val_f1_hist.append(f1)

    # Scheduler (no-op under patience=1000)
    scheduler.step(val_loss_epoch)

    print(f"\nEpoch {epoch:2d} SUMMARY:")
    print(f"  Train Loss: {train_loss_epoch:.4f} | Train Acc: {train_acc:.4f} | Train BA: {train_bal_acc:.4f} | Train F1: {train_f1:.4f}")
    print(f"  Val   Loss: {val_loss_epoch:.4f} | Val   Acc: {acc:.4f} | Val   BA: {bal_acc:.4f} | Val   F1: {f1:.4f}\n")
    print(f"  Val confusion: {np.bincount(val_y_pred) if len(all_val_preds) else 'N/A'}")

    # Early stopping по валид баланс-acc
    if bal_acc > best_bal_acc + 1e-4:
        best_bal_acc = bal_acc
        best_state_dict = model.state_dict()
        epochs_no_improve = 0
        print(f"  --> New best model (val balanced acc improved to {best_bal_acc:.4f})\n")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping after {epoch} epochs (no val BA improvement for {patience} epochs).")
            break

if best_state_dict:
    model.load_state_dict(best_state_dict)
    print(f"Best balanced accuracy on validation: {best_bal_acc:.4f}")
else:
    print("No improvement on validation set.")

# --- Графики ---
epochs_ran = len(train_loss_hist)
fig, axs = plt.subplots(2, 2, figsize=(13, 9))

axs[0,0].plot(range(1, epochs_ran+1), train_loss_hist, label='Train Loss')
axs[0,0].plot(range(1, epochs_ran+1), val_loss_hist, label='Val Loss')
axs[0,0].set_title("Loss")
axs[0,0].set_xlabel("Epoch")
axs[0,0].legend()
axs[0,0].grid()

axs[0,1].plot(range(1, epochs_ran+1), train_bal_acc_hist, label='Train BA')
axs[0,1].plot(range(1, epochs_ran+1), val_bal_acc_hist, label='Val BA')
axs[0,1].set_title("Balanced Accuracy")
axs[0,1].set_xlabel("Epoch")
axs[0,1].legend()
axs[0,1].grid()

axs[1,0].plot(range(1, epochs_ran+1), train_acc_hist, label='Train Acc')
axs[1,0].plot(range(1, epochs_ran+1), val_acc_hist, label='Val Acc')
axs[1,0].set_title("Accuracy")
axs[1,0].set_xlabel("Epoch")
axs[1,0].legend()
axs[1,0].grid()

axs[1,1].plot(range(1, epochs_ran+1), train_f1_hist, label='Train F1')
axs[1,1].plot(range(1, epochs_ran+1), val_f1_hist, label='Val F1')
axs[1,1].set_title("F1 Score")
axs[1,1].set_xlabel("Epoch")
axs[1,1].legend()
axs[1,1].grid()

fig.tight_layout()
plt.show()

NameError: name 'X_train' is not defined

In [None]:
1

In [76]:
df_segments.subject_id.nunique()

29

In [36]:
dict(df_eegs.iloc[-1])

{'path': '/mnt/d/Study/PhD/Data/EEG/own/Mor_y1_004_own_face.set',
 'session_type': 'own',
 'mono': np.False_,
 'subject_id': 'Mor_y1_004',
 'duration_sec': np.float64(746.172),
 'ann_onset': array([  5.4095,   6.2615,   6.8955,  19.159 , 137.659 , 260.059 ,
        378.359 , 498.859 , 618.359 ]),
 'ann_description': array(['boundary', 'boundary', 'boundary', 'S1', 'S2', 'S1', 'S2', 'S1',
        'S2'], dtype=object),
 'all_segments': [5.4095,
  6.2615,
  6.8955,
  19.159,
  137.659,
  260.059,
  378.359,
  498.859,
  618.359,
  746.172],
 'segment_durations': [0.85,
  0.63,
  12.26,
  118.5,
  122.4,
  118.3,
  120.5,
  119.5,
  127.81],
 'good_eeg': np.False_,
 'bad_segments': [0.85, 0.63, 12.26],
 'segments_S1': [(19.159, 137.659), (260.059, 378.359), (498.859, 618.359)],
 'segments_S2': [(137.659, 260.059), (378.359, 498.859), (618.359, 746.172)]}

In [None]:
# состояния глаз

# 'S 1' -- closed
# 'S 2' -- open

# boundary -- skip

Unnamed: 0,path,session_type,subject_id,mono,duration_sec,ann_onset,ann_description
0,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_003_fon1.set,fon,Co_y6_003,False,599.032,"[36.292, 132.311, 218.231, 305.396, 399.016, 4...","[S 1, S 2, S 1, S 2, S 1, S 2]"
1,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_008_fon1.set,fon,Co_y6_008,False,718.248,"[8.361, 127.894, 244.88, 360.293, 478.137, 598...","[S 1, S 2, S 1, S 2, S 1, S 2]"
2,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_009_fon1.set,fon,Co_y6_009,False,756.598,"[11.787, 126.276, 230.197, 339.358, 425.819, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
3,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_010_fon1.set,fon,Co_y6_010,False,616.164,"[27.844, 141.183, 252.179, 351.057, 422.676, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
4,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_013_Fon1.set,fon,Co_y6_013,False,366.824,"[2.02, 98.647, 143.365, 205.839, 244.698, 316....","[S 1, S 2, S 1, S 2, S 1, S 2]"
...,...,...,...,...,...,...,...
91,/mnt/d/Study/PhD/Data/EEG/own/Co_y6_mono_002_o...,own,Co_y6_002,True,669.910,"[11.769, 115.769, 238.769, 357.769, 481.769, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
92,/mnt/d/Study/PhD/Data/EEG/own/Mor_y1_001_own_f...,own,Mor_y1_001,False,731.932,"[31.9455, 32.2, 148.3, 206.3135, 267.3, 274.86...","[boundary, S1, S2, boundary, S1, boundary, S2,..."
93,/mnt/d/Study/PhD/Data/EEG/own/Mor_y1_002_own_f...,own,Mor_y1_002,False,676.492,"[22.0, 134.0, 139.6125, 164.6385, 183.4675, 24...","[S1, S2, boundary, boundary, boundary, S1, bou..."
94,/mnt/d/Study/PhD/Data/EEG/own/Mor_y1_003_own_f...,own,Mor_y1_003,False,736.544,"[8.6435, 11.8305, 17.712, 135.912, 256.912, 27...","[boundary, boundary, S1, S2, S1, boundary, bou..."
