In [5]:
import glob
from collections import defaultdict

import pandas as pd
import mne
import numpy as np

from tqdm import tqdm
from pathlib import Path

In [6]:
eeg_files = glob.glob('/mnt/d/Study/PhD/Data/EEG/*/*.set')
df_eegs = pd.DataFrame(eeg_files, columns=['path'])

In [None]:
target_sfreq = 500  # Hz

# Sliding window params
segment_sec = 5.0           # window length in seconds (crop length)
overlap_sec = 2.0           # overlap in seconds
segment_margin_sec = 5.0      # clean margin from event start/end in seconds

In [None]:
#set_file = '/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_041_fon1.set'

In [8]:
set_file = '/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_003_fon1.set'

raw = mne.io.read_raw_eeglab(set_file, preload=True)
print("Original sfreq:", raw.info["sfreq"])

if raw.info["sfreq"] != target_sfreq:
    raw.resample(target_sfreq)
    print("Resampled sfreq:", raw.info["sfreq"])
else:
    print("Already at target sfreq:", raw.info["sfreq"])

Original sfreq: 1000.0
Resampled sfreq: 500.0


In [12]:
sfreq = raw.info["sfreq"]
data = raw.get_data()  # shape: (n_channels, n_samples)
n_channels, n_samples = data.shape
print("Data shape:", data.shape)

Data shape: (128, 299516)


In [15]:
n_samples / 500

599.032

In [11]:
ann = raw.annotations
ann.onset, ann.duration, ann.description

(array([ 36.292, 132.311, 218.231, 305.396, 399.016, 496.607]),
 array([0.001, 0.001, 0.001, 0.001, 0.001, 0.001]),
 array(['S   1', 'S   2', 'S   1', 'S   2', 'S   1', 'S   2'], dtype='<U5'))

In [16]:
from joblib import Parallel, delayed
import pandas as pd
from tqdm import tqdm

def extract_ann_info(path, target_sfreq=500):
    """Load a .set file, return dict with path, total duration (s), ann.onset, ann.description."""
    try:
        raw = mne.io.read_raw_eeglab(path, preload=True, verbose=False)
        if raw.info["sfreq"] != target_sfreq:
            raw.resample(target_sfreq, verbose=False)
        n_samples = raw.n_times
        sfreq = raw.info["sfreq"]
        duration_sec = n_samples / sfreq
        ann = raw.annotations
        # Cast onset and description to Python list for DataFrame compatibility
        onset_list = ann.onset.tolist()
        description_list = list(ann.description)
        return {
            'path': path,
            'duration_sec': duration_sec,
            'ann_onset': onset_list,
            'ann_description': description_list
        }
    except Exception as e:
        return {
            'path': path,
            'duration_sec': None,
            'ann_onset': None,
            'ann_description': None,
            'error': str(e)
        }

# Assuming df_eegs['path'] contains paths to EEG files
records = Parallel(n_jobs=-1, verbose=1)(
    delayed(extract_ann_info)(path, target_sfreq)
    for path in tqdm(df_eegs['path'], desc="Reading EEGs with ann info")
)

df_ann_info = pd.DataFrame(records)
# If you want to see only successful cases (filter out errors if present):
# df_ann_info = df_ann_info[df_ann_info["duration_sec"].notnull()].reset_index(drop=True)
df_ann_info.head()

Reading EEGs with ann info:   0%|          | 0/96 [00:00<?, ?it/s][Parallel(n_jobs=-1)]: Using backend LokyBackend with 32 concurrent workers.
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
Reading EEGs with ann info: 100%|██████████| 96/96 [01:11<00:00,  1.34it/s]
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
[Parallel(n_jobs=-1)]: Done  96 out of  96 | elapsed:  2.5min finished


Unnamed: 0,path,duration_sec,ann_onset,ann_description
0,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_003_fon1.set,599.032,"[36.292, 132.311, 218.231, 305.396, 399.016, 4...","[S 1, S 2, S 1, S 2, S 1, S 2]"
1,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_008_fon1.set,718.248,"[8.361, 127.894, 244.88, 360.293, 478.137, 598...","[S 1, S 2, S 1, S 2, S 1, S 2]"
2,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_009_fon1.set,756.598,"[11.787, 126.276, 230.197, 339.358, 425.819, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
3,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_010_fon1.set,616.164,"[27.844, 141.183, 252.179, 351.057, 422.676, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
4,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_013_Fon1.set,366.824,"[2.02, 98.647, 143.365, 205.839, 244.698, 316....","[S 1, S 2, S 1, S 2, S 1, S 2]"


In [18]:
df_ann_info[:10]

Unnamed: 0,path,duration_sec,ann_onset,ann_description
0,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_003_fon1.set,599.032,"[36.292, 132.311, 218.231, 305.396, 399.016, 4...","[S 1, S 2, S 1, S 2, S 1, S 2]"
1,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_008_fon1.set,718.248,"[8.361, 127.894, 244.88, 360.293, 478.137, 598...","[S 1, S 2, S 1, S 2, S 1, S 2]"
2,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_009_fon1.set,756.598,"[11.787, 126.276, 230.197, 339.358, 425.819, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
3,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_010_fon1.set,616.164,"[27.844, 141.183, 252.179, 351.057, 422.676, 5...","[S 1, S 2, S 1, S 2, S 1, S 2]"
4,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_013_Fon1.set,366.824,"[2.02, 98.647, 143.365, 205.839, 244.698, 316....","[S 1, S 2, S 1, S 2, S 1, S 2]"
5,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_014_fon1.set,430.8,"[2.422, 114.116, 223.771, 283.804, 330.991, 37...","[S 1, S 2, S 1, S 2, S 1, S 2]"
6,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_016st_fon1...,724.268,"[7.869, 127.416, 247.423, 367.446, 485.62, 605...","[S 1, S 2, S 1, S 2, S 1, S 2]"
7,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_020_Fon1.set,658.736,"[2.762, 115.947, 222.669, 335.911, 436.163, 55...","[S 1, S 2, S 1, S 2, S 1, S 2]"
8,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_021_fon1.set,727.936,"[36.808, 150.913, 265.106, 382.931, 498.136, 6...","[S 1, S 2, S 1, S 2, S 1, S 2]"
9,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_022_fon1.set,680.956,"[5.644, 118.733, 225.986, 342.245, 451.938, 57...","[S 1, S 2, S 1, S 2, S 1, S 2]"


In [30]:
df_ann_info['all_segments'] = df_ann_info['ann_onset'].apply(lambda x: [float(seg_start) for seg_start in x] + [float(raw.duration)])
df_ann_info['segment_durations'] = df_ann_info['all_segments'].apply(lambda x: [round(float(x[i+1]) - float(x[i]), 2) for i in range(len(x)-1)])

# all durations > 50 and < 130
df_ann_info['good_eeg'] = df_ann_info['segment_durations'].apply(lambda x: all(50 < d < 130 for d in x))
df_ann_info['bad_segments'] = df_ann_info['segment_durations'].apply(lambda x: [d for d in x if d < 50 or d > 130])

In [29]:
df_ann_info['good_eeg'].sum()

np.int64(23)

In [58]:
df_ann_info[~df_ann_info.good_eeg].head(10)

Unnamed: 0,path,duration_sec,ann_onset,ann_description,all_segments,segment_durations,good_eeg,bad_segments
1,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_008_fon1.set,718.248,"[8.361, 127.894, 244.88, 360.293, 478.137, 598...","[S 1, S 2, S 1, S 2, S 1, S 2]","[8.361, 127.894, 244.88, 360.293, 478.137, 598...","[119.53, 116.99, 115.41, 117.84, 120.02, 0.87]",False,[0.87]
4,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_013_Fon1.set,366.824,"[2.02, 98.647, 143.365, 205.839, 244.698, 316....","[S 1, S 2, S 1, S 2, S 1, S 2]","[2.02, 98.647, 143.365, 205.839, 244.698, 316....","[96.63, 44.72, 62.47, 38.86, 71.58, 282.76]",False,"[44.72, 38.86, 282.76]"
5,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_014_fon1.set,430.8,"[2.422, 114.116, 223.771, 283.804, 330.991, 37...","[S 1, S 2, S 1, S 2, S 1, S 2]","[2.422, 114.116, 223.771, 283.804, 330.991, 37...","[111.69, 109.65, 60.03, 47.19, 39.23, 228.81]",False,"[47.19, 39.23, 228.81]"
6,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_016st_fon1...,724.268,"[7.869, 127.416, 247.423, 367.446, 485.62, 605...","[S 1, S 2, S 1, S 2, S 1, S 2]","[7.869, 127.416, 247.423, 367.446, 485.62, 605...","[119.55, 120.01, 120.02, 118.17, 120.02, -6.61]",False,[-6.61]
7,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_020_Fon1.set,658.736,"[2.762, 115.947, 222.669, 335.911, 436.163, 55...","[S 1, S 2, S 1, S 2, S 1, S 2]","[2.762, 115.947, 222.669, 335.911, 436.163, 55...","[113.19, 106.72, 113.24, 100.25, 118.83, 44.04]",False,[44.04]
8,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_021_fon1.set,727.936,"[36.808, 150.913, 265.106, 382.931, 498.136, 6...","[S 1, S 2, S 1, S 2, S 1, S 2]","[36.808, 150.913, 265.106, 382.931, 498.136, 6...","[114.11, 114.19, 117.82, 115.21, 119.92, -19.02]",False,[-19.02]
9,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_022_fon1.set,680.956,"[5.644, 118.733, 225.986, 342.245, 451.938, 57...","[S 1, S 2, S 1, S 2, S 1, S 2]","[5.644, 118.733, 225.986, 342.245, 451.938, 57...","[113.09, 107.25, 116.26, 109.69, 118.84, 28.26]",False,[28.26]
10,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_024_fon1.set,655.99,"[8.767, 122.223, 234.441, 350.376, 448.842, 56...","[S 1, S 2, S 1, S 2, S 1, S 2]","[8.767, 122.223, 234.441, 350.376, 448.842, 56...","[113.46, 112.22, 115.93, 98.47, 113.59, 36.6]",False,[36.6]
11,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_025_fon1.set,669.854,"[4.929, 113.431, 223.167, 332.783, 444.887, 55...","[S 1, S 2, S 1, S 2, S 1, S 2]","[4.929, 113.431, 223.167, 332.783, 444.887, 55...","[108.5, 109.74, 109.62, 112.1, 112.34, 41.81]",False,[41.81]
12,/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_026_fon1.set,754.796,"[0.0, 0.166, 48.114, 165.304, 282.819, 400.206...","[boundary, Record, S 1, S 2, S 1, S 2,...","[0.0, 0.166, 48.114, 165.304, 282.819, 400.206...","[0.17, 47.95, 117.19, 117.52, 117.39, 117.58, ...",False,"[0.17, 47.95, -155.61]"


In [53]:
df_ann_info.to_csv('df_ann_info.parquet', index=False)

In [59]:
# Откроем последнее ЭЭГ-файл из df_ann_info и посчитаем статистику по сегментам
# Определять состояние глаз будем по статистике ЭЭГ, а не по меткам S1/S2

last_row = df_ann_info.iloc[12]
eeg_path = last_row['path']

# Загрузим файл с помощью mne
raw = mne.io.read_raw_eeglab(eeg_path, preload=True, verbose=False)

# Сегменты по аннотациям
onsets = last_row['all_segments']
descriptions = last_row['ann_description']
segment_durations = last_row['segment_durations']

import numpy as np

# Для примера возьмем feature: мощность альфа-ритма (8-13 Гц) как признак закрытых глаз
def calc_alpha_power(data, sfreq):
    from scipy.signal import welch
    psds = []
    for ch_data in data:
        f, Pxx = welch(ch_data, fs=sfreq, nperseg=sfreq*2)
        # интеграл мощности в диапазоне 8-13 Гц
        alpha_power = np.trapz(Pxx[(f >= 8) & (f <= 13)], f[(f >= 8) & (f <= 13)])
        psds.append(alpha_power)
    return float(np.mean(psds))

print("Статистика и по предположениям о состоянии глаз по каждому сегменту:")

# Для сравнения будем считать состояние глаз по относительной мощности альфа-ритма
alpha_powers = []
summary_stats = []
for i in range(len(segment_durations)):
    start = onsets[i]
    stop = onsets[i + 1]
    dur = segment_durations[i]
    desc = descriptions[i] if i < len(descriptions) else None

    # Получим данные сегмента (секунды в sample-индексы)
    sfreq = raw.info['sfreq']
    start_sample = int(start * sfreq)
    stop_sample = int(stop * sfreq)
    data_seg = raw.get_data()[:, start_sample:stop_sample]

    if data_seg.size == 0:
        stats = {
            'segment': i,
            'desc': desc,
            'start_sec': start,
            'stop_sec': stop,
            'duration_sec': dur,
            'mean': None,
            'std': None,
            'min': None,
            'max': None,
            'alpha_power': None,
            'inferred_eye_state': None,
            'warning': 'zero-size segment'
        }
        print(stats)
        summary_stats.append(stats)
        alpha_powers.append(0)
        continue

    mean = float(np.mean(data_seg))
    std = float(np.std(data_seg))
    datamin = float(np.min(data_seg))
    datamax = float(np.max(data_seg))

    alpha_power = calc_alpha_power(data_seg, sfreq)
    alpha_powers.append(alpha_power)

    stats = {
        'segment': i,
        'desc': desc,
        'start_sec': start,
        'stop_sec': stop,
        'duration_sec': dur,
        'mean': mean,
        'std': std,
        'min': datamin,
        'max': datamax,
        'alpha_power': alpha_power
    }
    summary_stats.append(stats)

# Определим threshold по альфа мощности: если мощности по сегментам bimodal, возьмём медиану как условную границу
alpha_powers_np = np.array(alpha_powers)
nonzero_mask = alpha_powers_np > 0
median_alpha = np.median(alpha_powers_np[nonzero_mask]) if np.any(nonzero_mask) else 0

# Оценим глазные состояния purely по статистике альфа-ритма
for i, stats in enumerate(summary_stats):
    alpha_power = stats.get('alpha_power')
    if alpha_power is None:
        stats["inferred_eye_state"] = None
    else:
        # Предположение: при закрытых глазах альфа выше (выше медианы)
        # Это простое разделение, для реального случая нужен анализ распределения!
        if alpha_power > median_alpha:
            state = 'скорее всего ЗАКРЫТЫ глаза (альфа↑)'
        else:
            state = 'скорее всего ОТКРЫТЫ глаза (альфа↓)'
        stats['inferred_eye_state'] = state

    print(stats)



  raw = mne.io.read_raw_eeglab(eeg_path, preload=True, verbose=False)
  f, Pxx = welch(ch_data, fs=sfreq, nperseg=sfreq*2)
  alpha_power = np.trapz(Pxx[(f >= 8) & (f <= 13)], f[(f >= 8) & (f <= 13)])


Статистика и по предположениям о состоянии глаз по каждому сегменту:
{'segment': 0, 'desc': np.str_('boundary'), 'start_sec': 0.0, 'stop_sec': 0.166, 'duration_sec': 0.17, 'mean': 2.5782799632545592e-14, 'std': 5.953366703999823e-06, 'min': -3.135120964050293e-05, 'max': 3.518932723999023e-05, 'alpha_power': 0.0, 'inferred_eye_state': 'скорее всего ОТКРЫТЫ глаза (альфа↓)'}
{'segment': 1, 'desc': np.str_('Record'), 'start_sec': 0.166, 'stop_sec': 48.114, 'duration_sec': 47.95, 'mean': -1.5106221567665995e-15, 'std': 1.140825848518049e-05, 'min': -0.0007663497924804687, 'max': 0.00323926611328125, 'alpha_power': 8.683946859338694e-12, 'inferred_eye_state': 'скорее всего ОТКРЫТЫ глаза (альфа↓)'}
{'segment': 2, 'desc': np.str_('S   1'), 'start_sec': 48.114, 'stop_sec': 165.304, 'duration_sec': 117.19, 'mean': 1.0676188584037835e-15, 'std': 7.810198342902996e-06, 'min': -0.0001717032470703125, 'max': 0.0001567550354003906, 'alpha_power': 2.6193262560893432e-11, 'inferred_eye_state': 'скорее

In [34]:
dict(df_ann_info.iloc[-1])

{'path': '/mnt/d/Study/PhD/Data/EEG/own/Mor_y1_004_own_face.set',
 'duration_sec': np.float64(746.172),
 'ann_onset': [5.4095,
  6.2615,
  6.8955,
  19.159,
  137.659,
  260.059,
  378.359,
  498.859,
  618.359],
 'ann_description': [np.str_('boundary'),
  np.str_('boundary'),
  np.str_('boundary'),
  np.str_('S1'),
  np.str_('S2'),
  np.str_('S1'),
  np.str_('S2'),
  np.str_('S1'),
  np.str_('S2')],
 'all_segments': [5.4095,
  6.2615,
  6.8955,
  19.159,
  137.659,
  260.059,
  378.359,
  498.859,
  618.359,
  599.032],
 'segment_durations': [0.85,
  0.63,
  12.26,
  118.5,
  122.4,
  118.3,
  120.5,
  119.5,
  -19.33],
 'good_eeg': np.False_,
 'bad_segments': [0.85, 0.63, 12.26, -19.33]}

In [27]:
dict(df_ann_info.iloc[4])

{'path': '/mnt/d/Study/PhD/Data/EEG/fon/Co_y6_013_Fon1.set',
 'duration_sec': np.float64(366.824),
 'ann_onset': [2.02, 98.647, 143.365, 205.839, 244.698, 316.274],
 'ann_description': [np.str_('S   1'),
  np.str_('S   2'),
  np.str_('S   1'),
  np.str_('S   2'),
  np.str_('S   1'),
  np.str_('S   2')],
 'all_segments': [2.02, 98.647, 143.365, 205.839, 244.698, 316.274, 599.032],
 'segment_durations': [96.63, 44.72, 62.47, 38.86, 71.58, 282.76]}

In [None]:
from joblib import Parallel, delayed
import traceback

def load_eeg_file(path, target_sfreq=500):
    """Load and resample a single EEG file, return (data, path) or (None, path) on error."""
    try:
        raw = mne.io.read_raw_eeglab(path, preload=True, verbose=False)
        if raw.info["sfreq"] != target_sfreq:
            raw.resample(target_sfreq, verbose=False)
        data = raw.get_data()  # shape: (n_channels, n_samples)
        return (data, path)
    except Exception as e:
        print(f"\nERROR loading file: {path}")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        return (None, path)  # Return None for data, but keep path for tracking

# Load all files in parallel (adjust n_jobs based on your CPU cores)
results = Parallel(n_jobs=-1, verbose=1)(
    delayed(load_eeg_file)(path, target_sfreq) 
    for path in tqdm(df_eegs['path'], desc="Loading EEGs")
)

# Separate successful loads from errors
array_eegs = [data for data, path in results if data is not None]
error_paths = [path for data, path in results if data is None]

print(f"\nSuccessfully loaded: {len(array_eegs)}/{len(results)} files")
if error_paths:
    print(f"Failed to load {len(error_paths)} files:")
    for path in error_paths:
        print(f"  - {path}")

Loading EEGs:   0%|          | 0/96 [00:00<?, ?it/s][Parallel(n_jobs=-1)]: Using backend LokyBackend with 32 concurrent workers.
  warn(
  warn(
  warn(
  warn(
Exception ignored in: <_io.BytesIO object at 0x773d5c53df30>
Traceback (most recent call last):
  File "/home/whatislove/miniconda3/envs/phd/lib/python3.13/site-packages/joblib/externals/loky/process_executor.py", line 110, in _get_memory_usage
    gc.collect()
BufferError: Existing exports of data: object cannot be re-sized
Loading EEGs:  33%|███▎      | 32/96 [00:19<00:00, 206.85it/s]

In [15]:
len(array_eegs)

96

In [12]:
ann = raw.annotations
ann.onset, ann.duration, ann.description

segmen_starts = [float(seg_start) for seg_start in ann.onset] + [float(raw.duration)]

In [13]:
intervals = [(float(segmen_starts[i]) + segment_margin_sec, float(segmen_starts[i+1]) - segment_margin_sec) for i in range(len(segmen_starts) - 1)] 

In [1]:
events = [0, 1, 0, 1, 0, 1]

In [15]:
eeg_crops = defaultdict(list)

for event, interval in zip(events, intervals):
    interval_segments = []
    interval_start = interval[0]
    interval_end = interval[1]

    while interval_start < interval_end:
        interval_segments.append((int(interval_start * target_sfreq), int((interval_start + segment_sec) * target_sfreq)))
        interval_start += segment_sec - overlap_sec
    eeg_crops[event].extend(interval_segments)

In [18]:
85 * 2 * 30 * 3

15300

In [19]:
# crop numpy array from raw EEG data
eeg_array = raw.get_data()

crop = eeg_array[:, eeg_crops[1][0][0]:eeg_crops[1][0][1]]

In [22]:
import torch

  import pynvml  # type: ignore[import]


In [27]:
batch_size = 50
# Take batch_size different crops from eeg_array based on eeg_crops[1] and move to cuda
crop_batch = np.stack([
    eeg_array[:, eeg_crops[1][i][0]:eeg_crops[1][i][1]] 
    for i in range(batch_size)
])
crop_batch = torch.from_numpy(crop_batch).to('cuda')

In [29]:
dataset = [batch_size] * 100

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EEGNet(nn.Module):
    def __init__(self, n_channels=128, in_samples=2500, n_classes=2):
        super(EEGNet, self).__init__()
        # First temporal convolution
        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=7, stride=1, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        # Depthwise convolution
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2, groups=64)
        self.bn2 = nn.BatchNorm1d(128)
        # Pointwise convolution
        self.conv3 = nn.Conv1d(128, 256, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(256)
        # Temporal pooling
        self.pool = nn.AdaptiveAvgPool1d(32)
        self.dropout = nn.Dropout(0.5)
        # Fully connected classification head
        self.fc1 = nn.Linear(256 * 32, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        # x: (batch_size, n_channels, in_samples)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # flatten
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x  # logits

# Usage example for (batch_size, 128, 2500) input
eeg_sample = torch.randn(batch_size, 128, 2500).to('cuda')
model = EEGNet().to('cuda')
out = model(eeg_sample)  # out shape: (batch_size, 2)


In [34]:
# import math

# class EEGTransformer(nn.Module):
#     def __init__(self, n_channels=128, in_samples=2500, n_classes=2, d_model=128, nhead=8, num_layers=4, dim_feedforward=256, dropout=0.3):
#         super(EEGTransformer, self).__init__()
#         self.pos_embedding = nn.Parameter(torch.zeros(1, n_channels, d_model))
#         self.input_proj = nn.Linear(in_samples, d_model)
#         encoder_layer = nn.TransformerEncoderLayer(
#             d_model=d_model,
#             nhead=nhead,
#             dim_feedforward=dim_feedforward,
#             dropout=dropout,
#             activation='gelu',
#             batch_first=True)
#         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
#         self.classifier = nn.Sequential(
#             nn.LayerNorm(d_model),
#             nn.Linear(d_model, n_classes)
#         )

#     def forward(self, x):
#         # x: (batch_size, n_channels, in_samples)
#         x = self.input_proj(x)                     # (batch_size, n_channels, d_model)
#         x = x + self.pos_embedding                 # positional encoding
#         out = self.transformer_encoder(x)          # (batch_size, n_channels, d_model)
#         out = out.mean(dim=1)                      # (batch_size, d_model) - global avg pooling channels
#         out = self.classifier(out)                 # (batch_size, n_classes)
#         return out  # logits

# # Usage example for (batch_size, 128, 2500) input
# eeg_sample = torch.randn(batch_size, 128, 2500).to('cuda')
# transf_model = EEGTransformer(n_channels=128, in_samples=2500, n_classes=2).to('cuda')
# out_transformer = transf_model(eeg_sample)  # out shape: (batch_size, 2)


In [90]:
import torch

  import pynvml  # type: ignore[import]


In [91]:
torch.cuda.is_available()

True

In [108]:
torch_crop = torch.from_numpy(crop).to(torch.float32)

In [111]:
torch_crop.to('cuda')

tensor([[ 1.3564e-06,  1.0416e-06,  1.9863e-07,  ..., -2.0719e-06,
         -2.5011e-06, -3.1162e-06],
        [-5.7126e-06, -5.4713e-06, -5.1856e-06,  ..., -1.4616e-06,
         -1.9963e-06, -2.9750e-06],
        [ 3.6175e-06,  3.5242e-06,  3.2241e-06,  ..., -2.1745e-06,
         -2.2885e-06, -2.6498e-06],
        ...,
        [ 4.5774e-06,  5.0745e-06,  5.7608e-06,  ...,  3.7391e-07,
          2.8602e-07,  5.0206e-07],
        [ 5.8502e-06,  6.0943e-06,  6.1615e-06,  ...,  8.5529e-07,
          7.4678e-07,  1.1209e-06],
        [-2.0557e-07,  6.2699e-07,  1.3594e-06,  ...,  4.2139e-07,
          3.4264e-07,  7.3738e-07]], device='cuda:0')

In [112]:
batch_size = 16
batch = torch.stack([torch_crop] * 16)

In [113]:
batch.to('cuda')

tensor([[[ 1.3564e-06,  1.0416e-06,  1.9863e-07,  ..., -2.0719e-06,
          -2.5011e-06, -3.1162e-06],
         [-5.7126e-06, -5.4713e-06, -5.1856e-06,  ..., -1.4616e-06,
          -1.9963e-06, -2.9750e-06],
         [ 3.6175e-06,  3.5242e-06,  3.2241e-06,  ..., -2.1745e-06,
          -2.2885e-06, -2.6498e-06],
         ...,
         [ 4.5774e-06,  5.0745e-06,  5.7608e-06,  ...,  3.7391e-07,
           2.8602e-07,  5.0206e-07],
         [ 5.8502e-06,  6.0943e-06,  6.1615e-06,  ...,  8.5529e-07,
           7.4678e-07,  1.1209e-06],
         [-2.0557e-07,  6.2699e-07,  1.3594e-06,  ...,  4.2139e-07,
           3.4264e-07,  7.3738e-07]],

        [[ 1.3564e-06,  1.0416e-06,  1.9863e-07,  ..., -2.0719e-06,
          -2.5011e-06, -3.1162e-06],
         [-5.7126e-06, -5.4713e-06, -5.1856e-06,  ..., -1.4616e-06,
          -1.9963e-06, -2.9750e-06],
         [ 3.6175e-06,  3.5242e-06,  3.2241e-06,  ..., -2.1745e-06,
          -2.2885e-06, -2.6498e-06],
         ...,
         [ 4.5774e-06,  5

In [None]:
# empty memory cuda

torch.cuda.empty_cache()

In [106]:
# remove tensor from cuda
torch.cuda.empty_cache()

In [82]:
299516 / 500

599.032

In [None]:
1736 

In [85]:
eeg_crops[1][-1]

(295803, 298303)

In [104]:
torch.cuda.empty_cache()

In [7]:
ann

NameError: name 'ann' is not defined

In [20]:
# ============================================================
# BUILD CLEAN EVENT INTERVALS (IN SAMPLES)
# ============================================================

intervals = []  # list of (start_sample, end_sample, label)

ann = raw.annotations
print("Number of annotations:", len(ann))

for onset, duration, desc in zip(ann.onset, ann.duration, ann.description):
    if desc not in label_map:
        continue  # ignore annotations we don't care about

    # Convert onset + margin to sample index using raw.time_as_index()
    # This handles any offset, meas_date, etc.
    clean_start_time = onset + L_sec
    clean_end_time   = onset - L_sec
    print(clean_start_time, clean_end_time)
    # If event is too short for the margin, skip it
    if clean_end_time <= clean_start_time:
        continue

    start_sample = raw.time_as_index(clean_start_time)[0]
    end_sample   = raw.time_as_index(clean_end_time)[0]

    # Clip to recording bounds
    start_sample = max(start_sample, 0)
    end_sample   = min(end_sample, n_samples - 1)

    if end_sample <= start_sample:
        continue

    label = label_map[desc]
    intervals.append((start_sample, end_sample, label))

print("Number of clean intervals:", len(intervals))

Number of annotations: 6
41.292 31.293
137.311 127.31200000000001
223.231 213.232
310.396 300.397
404.016 394.017
501.607 491.608
Number of clean intervals: 0


In [21]:


# OPTIONAL: sort by start_sample
intervals.sort(key=lambda x: x[0])


# ============================================================
# SLIDING WINDOWS OVER FULL RECORDING
# ============================================================

window_len_samples = int(round(N_sec * sfreq))
step_samples       = int(round(step_sec * sfreq))

print("window_len_samples:", window_len_samples)
print("step_samples:", step_samples)

start_indices = np.arange(0, n_samples - window_len_samples + 1, step_samples)
n_windows = len(start_indices)
print("Number of windows:", n_windows)

# Allocate arrays
X = np.empty((n_windows, n_channels, window_len_samples), dtype=np.float32)
y = np.zeros(n_windows, dtype=np.int64)  # 0 = background by default

def assign_label_for_window(w_start, w_end, intervals):
    """
    Return label for window [w_start, w_end) in samples.
    Rule:
      - If window is fully inside exactly one clean interval: return that label
      - If window is inside multiple intervals: use the first one
      - If window doesn't fit any clean interval: return 0 (background)
    """
    for (s, e, lab) in intervals:
        if w_start >= s and w_end <= e:
            return lab
    return 0

for i, w_start in enumerate(start_indices):
    w_end = w_start + window_len_samples

    # Extract data segment
    X[i] = data[:, w_start:w_end]

    # Assign label based on clean intervals
    y[i] = assign_label_for_window(w_start, w_end, intervals)

print("X shape:", X.shape)
print("y unique labels:", np.unique(y))


# ============================================================
# SAVE RESULT
# ============================================================

# Optional: time axis for each window (relative, seconds)
time_axis = np.arange(window_len_samples) / sfreq

np.savez(
    out_file,
    X=X,
    y=y,
    sfreq=sfreq,
    ch_names=np.array(raw.ch_names),
    window_len_sec=N_sec,
    overlap_sec=M_sec,
    margin_sec=L_sec,
    time_axis=time_axis,
)

print("Saved to:", out_file)

Resampled sfreq: 500.0


In [22]:
# 3) Extract events from annotations
#    For EEGLAB imports, MNE usually stores markers as annotations.
events, mne_event_id = mne.events_from_annotations(raw)
print("Found event codes:", mne_event_id)

Used Annotations descriptions: [np.str_('S   1'), np.str_('S   2')]
Found event codes: {np.str_('S   1'): 1, np.str_('S   2'): 2}


In [24]:
print("Using event_id:", event_id)

# 4) Create epochs (crops around each event)
epochs = mne.Epochs(
    raw,
    events=events,
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    baseline=baseline,
    preload=True,
)

Using event_id: {np.str_('S   1'): 1, np.str_('S   2'): 2}
Not setting metadata
6 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 6 events and 501 original time points ...
0 bad epochs dropped


In [25]:
print("Epochs data shape:", epochs.get_data().shape)  # (n_epochs, n_channels, n_times)

# 5) Extract NN-ready arrays
# X: (n_epochs, n_channels, n_times)
X = epochs.get_data().astype(np.float32)

Epochs data shape: (6, 128, 501)


In [31]:
X[0].shape

(128, 501)

In [26]:
# y_raw: event code for each epoch (integers from event_id)
y_raw = epochs.events[:, 2]

In [None]:
# Map event codes to 0..K-1 (more convenient for NN)
unique_codes = np.sort(np.unique(y_raw))
code_to_idx = {code: idx for idx, code in enumerate(unique_codes)}
y = np.array([code_to_idx[c] for c in y_raw], dtype=np.int64)

print("X shape:", X.shape)
print("y shape:", y.shape)
print("Unique labels (0..K-1):", np.unique(y))

# 6) Save for later use in your NN
np.savez(
    out_file,
    X=X,
    y=y,
    sfreq=epochs.info["sfreq"],
    ch_names=np.array(epochs.ch_names),
    tmin=tmin,
    tmax=tmax,
)
print("Saved:", out_file)