# Нарезка кропов для обучения модели

In [None]:
from dataclasses import dataclass, field
import os
import warnings

import mne
import numpy as np
import pandas as pd
from tqdm import tqdm


In [2]:
df_raw = pd.read_parquet('../../data/classification/raw/segments_split_v1.parquet')

In [3]:
# known warning suppression

import warnings
import sys
import os

# Suppress warnings from Python, numpy, and C-level (including joblib workers)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
os.environ["PYTHONWARNINGS"] = "ignore"
if hasattr(sys, 'warnoptions'):
    sys.warnoptions = []

def block_warnings_in_worker():
    import warnings, os
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)
    os.environ["PYTHONWARNINGS"] = "ignore"

# Patch: run warning suppression inside the function/context run by joblib
from functools import wraps

def suppress_worker_warnings(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        block_warnings_in_worker()
        return func(*args, **kwargs)
    return wrapper

## Нарезка и сохранение кропов

In [4]:
@dataclass
class EEGCropsConfig:
    # Preprocessing params
    target_sfreq: int = 500  # Hz
    notch_filter: list[int] = field(default_factory=lambda: [50, 60])
    bandpass_filter: list[float] = field(default_factory=lambda: [0.5, 100])

    # Crop params
    segment_sec: float = 5.0                # window length in seconds (crop length)
    overlap_sec: float = 0.0                # overlap in seconds
    segment_margin_sec: float = 5.0         # clean margin from event start/end in seconds

# channels
EEG_USEFUL_30: list[str] = [
    'Fp1', 'Fp2',
    'AF3', 'AF4',
    'F7', 'F3', 'Fz', 'F4', 'F8',
    'FC1', 'FCz', 'FC2',
    'T7', 'C3', 'C4', 'T8',
    'CP1', 'CPz', 'CP2',
    'P7', 'P3', 'Pz', 'P4', 'P8',
    'PO3', 'POz', 'PO4',
    'O1', 'Oz', 'O2',
]
EEG_USEFUL_62: list[str] = [
    'Fp1', 'Fp2',
    'AF3', 'AF4',
    'F7', 'F3', 'Fz', 'F4', 'F8',
    'FC1', 'FCz', 'FC2',
    'T7', 'C3', 'C4', 'T8',
    'CP1', 'CPz', 'CP2',
    'P7', 'P3', 'Pz', 'P4', 'P8',
    'PO3', 'POz', 'PO4',
    'O1', 'Oz', 'O2',
    # extra frontal / prefrontal
    'AF7', 'AF8', 'AFp1', 'AFp2',
    'F1', 'F2', 'F5', 'F6',
    # extra fronto-central (NOTE: FC5/FC6 removed to make total = 64)
    'FT7', 'FT8', 'FC3', 'FC4',
    # extra temporal / temporo-parietal
    'FT9', 'FT10', 'TP7', 'TP8', 'TP9', 'TP10',
    # extra central
    'C1', 'C2', 'C5', 'C6',
    # extra centro-parietal
    'CP3', 'CP4', 'CP5', 'CP6',
    # extra parietal + parieto-occipital
    'P1', 'P2', 'P5', 'P6', 'PO7', 'PO8',
]


crop_config = EEGCropsConfig()


In [72]:
from typing import Any

from functools import wraps
from collections import Counter

def ignore_mne_warnings(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="The data contains 'boundary' events, indicating data discontinuities. Be cautious of filtering and epoching around these events.", category=RuntimeWarning)
            warnings.filterwarnings("ignore", message="Complex objects (like classes) are not supported. They are imported on a best effort base but your mileage will vary.", category=UserWarning)
            return func(*args, **kwargs)
    return wrapper

def process_full_segments(raw, crop_config, segment_list, good_mask, save_info):
    crops = []
    target_sfreq = crop_config.target_sfreq
    segment_sec = crop_config.segment_sec
    required_samples = int(round(target_sfreq * segment_sec))
    eps = 1e-5

    for i, (event, is_good) in enumerate(zip(segment_list, good_mask)):
        start, end = event
        crop_start = start + crop_config.segment_margin_sec
        crop_end = end - crop_config.segment_margin_sec
        t = crop_start
        while t + segment_sec <= crop_end + eps:
            # Compute sample indices to get exactly required_samples
            start_sample = int(round(t * target_sfreq))
            stop_sample = start_sample + required_samples

            crop = raw.get_data(start=start_sample, stop=stop_sample)

            # Double-check crop shape
            if crop.shape[1] != required_samples:
                print(f"WARNING: Crop at t={t:.3f} has shape[1]={crop.shape[1]} (should be {required_samples}). Adjusting...")

                if crop.shape[1] > required_samples:
                    crop = crop[:, :required_samples]
                else:  # pad with zeros if too short (rare)
                    pad_shape = (crop.shape[0], required_samples - crop.shape[1])
                    crop = np.concatenate([crop, np.zeros(pad_shape, dtype=crop.dtype)], axis=1)

            crop_info = {
                # инфа о кропе
                "crop_start_time": t,
                "crop_end_time": t + segment_sec,
                "event_type": save_info.get('event_type', 'evt'),
            }

            # Correctly source info from save_info, with None check (to avoid NoneType errors)
            event_type = save_info.get('event_type', 'evt') if save_info else 'evt'
            subject_id = save_info.get('subject_id') if save_info else None
            session_type = save_info.get('session_type') if save_info else None
            path = save_info.get('path') if save_info else None
            crops_dir = save_info.get('crops_dir') if save_info else None

            # Check that path and crops_dir are not None
            if path is None or crops_dir is None:
                print(f"Skipping crop {i}: missing 'path' or 'crops_dir' (path={path}, crops_dir={crops_dir})")
                t += segment_sec - crop_config.overlap_sec
                continue

            base_fn = os.path.splitext(os.path.basename(path))[0]
            crop_id = f"{base_fn}_{session_type}_{event_type}_{i}_{int(1000*t)}_{int(1000*(t+segment_sec))}.npy"
            crop_path = os.path.join(crops_dir, crop_id)
            np.save(crop_path, crop.astype(np.float32))

            crop_info['crop_path'] = crop_path
            crop_info['path'] = path  # добавляем исходный путь до ээг

            crops.append(crop_info)

            t += segment_sec - crop_config.overlap_sec
    return crops

@ignore_mne_warnings
def crop_eeg_full_events(
    df_segments: pd.DataFrame,
    crop_config,
    channels_to_keep: list[str],
    crops_dir: str
) -> pd.DataFrame:
    os.makedirs(crops_dir, exist_ok=True)

    target_sfreq = crop_config.target_sfreq
    notch_filter = crop_config.notch_filter
    bandpass_filter = crop_config.bandpass_filter

    missing_channels_counter = Counter()
    channels_stats_counter = Counter()
    all_crop_rows = []

    df_iter = df_segments.itertuples(index=False)
    with tqdm(total=len(df_segments), desc="Cropping EEGs from full events") as pbar:
        for row in df_iter:
            path = getattr(row, 'path', None)
            segs_S1 = getattr(row, 'segments_S1', None)
            segs_S2 = getattr(row, 'segments_S2', None)
            good_S1 = getattr(row, 'good_S1', [])
            good_S2 = getattr(row, 'good_S2', [])

            # Defensive check for path
            if path is None:
                print(f"Skipping row: missing EEG file path.")
                pbar.update(1)
                continue

            # save_info -- информация о файле и сессии, никакой инфы о сегменте
            save_info_base = {
                "path": path,
                "crops_dir": crops_dir,
                "subject_id": getattr(row, 'subject_id', None),
                "session_type": getattr(row, 'session_type', None)
            }

            try:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", message="The data contains 'boundary' events, indicating data discontinuities.", category=RuntimeWarning)
                    warnings.filterwarnings("ignore", message="Complex objects (like classes) are not supported.", category=UserWarning)
                    raw = mne.io.read_raw_eeglab(path, preload=True, verbose=False)
                channels_present = [ch for ch in channels_to_keep if ch in raw.ch_names]
                missing_channels = [ch for ch in channels_to_keep if ch not in channels_present]
                key = tuple(sorted(missing_channels))
                if key not in missing_channels_counter:
                    missing_channels_counter[key] = 0
                missing_channels_counter[key] += 1
                channels_stats_counter[len(channels_present)] = channels_stats_counter.get(len(channels_present), 0) + 1

                raw.pick(channels_present)

                if raw.info['sfreq'] != target_sfreq:
                    raw.resample(target_sfreq)

                raw.filter(l_freq=bandpass_filter[0], h_freq=bandpass_filter[1], fir_design='firwin', verbose=False)

                if isinstance(notch_filter, (list, tuple)):
                    freqs_to_notch = notch_filter
                else:
                    freqs_to_notch = [notch_filter]
                nyquist = raw.info['sfreq'] / 2.
                used_freqs = [f for f in freqs_to_notch if f < nyquist]
                if used_freqs:
                    raw.notch_filter(freqs=used_freqs, fir_design='firwin', verbose=False)

                # Handle S1 segments
                if segs_S1 is not None and len(segs_S1) > 0:
                    save_info_S1 = dict(save_info_base)
                    save_info_S1['event_type'] = 'S1'
                    crop_rows_S1 = process_full_segments(
                        raw, crop_config, segs_S1, good_S1, save_info_S1
                    )
                    all_crop_rows.extend(crop_rows_S1)

                # Handle S2 segments
                if segs_S2 is not None and len(segs_S2) > 0:
                    save_info_S2 = dict(save_info_base)
                    save_info_S2['event_type'] = 'S2'
                    crop_rows_S2 = process_full_segments(
                        raw, crop_config, segs_S2, good_S2, save_info_S2
                    )
                    all_crop_rows.extend(crop_rows_S2)

            except Exception as e:
                print(f"Failed {path}: {e}")
            pbar.update(1)

    df_crops = pd.DataFrame(all_crop_rows)

    if df_crops.empty:
        print("!!! Кропы не были сохранены или данные отсутствуют: итоговый df пустой !!!")

    print("=== Статистика числа каналов в crop ===")
    for n_ch, count in sorted(channels_stats_counter.items()):
        print(f"{n_ch} каналов: {count} кропов")

    print("\n=== Топ-10 самых частых комбинаций отсутствующих каналов ===")
    missed_sorted = Counter(missing_channels_counter).most_common(10)
    for miss, count in missed_sorted:
        if len(miss) == 0:
            continue
        print(f"{count} кропов без каналов: {miss}")

    return df_crops

In [73]:
save_dir = '/Users/whatislove/study/phd/data/processed/crops_v1_c1_500hz_30ch'

In [74]:
df_crops = crop_eeg_full_events(df_raw, crop_config, EEG_USEFUL_30, save_dir)

Cropping EEGs from full events: 100%|██████████| 80/80 [02:07<00:00,  1.60s/it]

=== Статистика числа каналов в crop ===
30 каналов: 80 кропов

=== Топ-10 самых частых комбинаций отсутствующих каналов ===





In [75]:
df_final = pd.merge(df_raw, df_crops, on='path', how='left')

In [76]:
df_final.to_parquet('../../data/classification/raw/crops_500_30_v1.parquet', index=False)

In [3]:
import pandas as pd
df_final = pd.read_parquet('../../data/classification/raw/crops_500_30_v1.parquet')

In [59]:
## Датасет для обучения

In [9]:
columns = ['path', 'crop_path', 'subject_id', 'session_type', 'crop_start_time', 'crop_end_time', 'event_type', 'good_S1', 'good_S2', 'segments_S1', 'segments_S2']

df_train = df_final[columns].loc[df_final['split'] == 'train']
df_valid = df_final[columns].loc[df_final['split'] == 'valid']

In [None]:
## метки, классификация состояния глаз S1/S2 открыты закрыты

# Сделаем target переменную: если event_type == 'S1', то 0, если 'S2', то 1
df_train['target'] = (df_train['event_type'] == 'S2').astype(int)
df_valid['target'] = (df_valid['event_type'] == 'S2').astype(int)

In [80]:
df_train.to_parquet('../../data/classification/processed/crops_500_30_v1_train.parquet', index=False)
df_valid.to_parquet('../../data/classification/processed/crops_500_30_v1_valid.parquet', index=False)

In [14]:
def make_target(row):
    if row['event_type'] == 'S2':
        return 0
    elif row['event_type'] == 'S1':
        if row.get('session_type') == 'fon':
            return 1
        elif row.get('session_type') == 'own':
            return 2
        elif row.get('session_type') == 'other':
            return 3
        else:
            return 1  # если ни один из признаков не выставлен, кладём как "other" (задайте по необходимости)
    else:
        return -1  # если вообще не S1/S2

df_train['target'] = df_train.apply(make_target, axis=1)
df_valid['target'] = df_valid.apply(make_target, axis=1)

In [16]:
df_train.to_parquet('../../data/classification/processed/crops_500_30_v1_state_train.parquet', index=False)
df_valid.to_parquet('../../data/classification/processed/crops_500_30_v1_state_valid.parquet', index=False)

In [None]:
3:1:1:1

In [15]:
df_train.target.value_counts(), df_valid.target.value_counts()

(target
 0    3128
 3    1232
 1    1169
 2    1085
 Name: count, dtype: int64,
 target
 0    775
 3    351
 1    275
 2    261
 Name: count, dtype: int64)

## Проверка датасета

In [84]:
import torch
import numpy as np
import pandas as pd

# Загружаем датафрейм с путями к кропам
df = pd.read_parquet('../../data/classification/processed/crops_500_30_v1_train.parquet')
print("Number of segments:", len(df))
print("First 3 crop paths:\n", df['crop_path'].head(3).tolist())

# Загружаем батч из 128 EEG-кропов
batch_size = 128
batch_paths = df['crop_path'].head(batch_size).tolist()

eeg_segments = []
expected_n_channels = 30
for p in batch_paths:
    d = np.load(p)  # shape: (n_channels, n_samples)
    if d.shape[0] != expected_n_channels:
        raise ValueError(f"Файл {p} имеет {d.shape[0]} каналов вместо 30!")
    eeg_segments.append(d)
eeg_segments = np.stack(eeg_segments)  # (batch_size, 62, n_samples)
print("EEG segments np array shape:", eeg_segments.shape)

# Преобразуем в torch tensor
inputs = torch.from_numpy(eeg_segments.astype(np.float32))  # (batch, 64, samples)
print("Input torch tensor shape:", inputs.shape)

# При необходимости на cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = inputs.to(device)

n_channels = expected_n_channels
in_samples = inputs.shape[2]
n_classes = 2  # или другое число классов

# Модель EEGNet для 30 каналов
class EEGNet(torch.nn.Module):
    def __init__(self, n_channels=30, in_samples=2500, n_classes=2):
        super(EEGNet, self).__init__()
        # Вход: (batch, 64, samples)
        self.conv1 = torch.nn.Conv1d(n_channels, 64, kernel_size=7, stride=1, padding=3)
        self.bn1 = torch.nn.BatchNorm1d(64)
        # groups=64 для depthwise-свёртки
        self.conv2 = torch.nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2, groups=64)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.conv3 = torch.nn.Conv1d(128, 256, kernel_size=1)
        self.bn3 = torch.nn.BatchNorm1d(256)
        self.pool = torch.nn.AdaptiveAvgPool1d(32)
        self.dropout = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(256 * 32, 128)
        self.fc2 = torch.nn.Linear(128, n_classes)
    def forward(self, x):
        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        x = torch.nn.functional.relu(self.bn2(self.conv2(x)))
        x = torch.nn.functional.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = EEGNet(n_channels=n_channels, in_samples=in_samples, n_classes=n_classes).to(device)

with torch.no_grad():
    out = model(inputs)
print("Model output shape:", out.shape)

Number of segments: 6614
First 3 crop paths:
 ['/Users/whatislove/study/phd/data/processed/crops_v1_c1_500hz_30ch/Co_y6_016st_otherface_other_S1_0_26124_31124.npy', '/Users/whatislove/study/phd/data/processed/crops_v1_c1_500hz_30ch/Co_y6_016st_otherface_other_S1_0_31124_36123.npy', '/Users/whatislove/study/phd/data/processed/crops_v1_c1_500hz_30ch/Co_y6_016st_otherface_other_S1_0_36123_41123.npy']
EEG segments np array shape: (128, 30, 2500)
Input torch tensor shape: torch.Size([128, 30, 2500])
Model output shape: torch.Size([128, 2])
