<a href="https://colab.research.google.com/github/Markvarte/EEG_analyzer/blob/master/eeg_analyzer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Data: download the EDF recording from https://zenodo.org/records/160118 and place it at /content/EEG_A.edf
# Dependencies: pip install mne braindecode

In [7]:
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from braindecode.models import EEGNetv4 as EEGNet

In [None]:
# window_size is given in samples, so its duration depends on the sampling rate:
#   EEG recorded at 512 Hz -> a 512-sample window spans 1 second
#   EEG recorded at 256 Hz -> a 512-sample window spans 2 seconds
class EEGDataset(Dataset):
    """Sliding-window EEG dataset that yields two augmented views per item.

    Every EDF file in ``file_list`` is read with MNE and cut along the time
    axis into windows of ``window_size`` samples. Consecutive windows start
    ``cut_step`` samples apart, so they overlap when
    ``cut_step < window_size``. With the defaults the windows are
    ``[0:512]``, ``[256:768]``, ``[512:1024]``, ... Trailing samples that do
    not fill a whole window are dropped (no padding).

    Args:
        file_list: Iterable of paths to EDF files.
        window_size: Window length in samples.
        cut_step: Offset in samples between the starts of consecutive
            windows.

    Raises:
        ValueError: If ``window_size`` or ``cut_step`` is not positive.
    """

    def __init__(self, file_list, window_size=512, cut_step=256):
        if window_size <= 0 or cut_step <= 0:
            raise ValueError(
                "window_size and cut_step must be positive, got "
                f"window_size={window_size}, cut_step={cut_step}"
            )
        # All windows from all files, each a float32 array [channels, window_size].
        self.samples = []
        for f in file_list:
            raw = mne.io.read_raw_edf(f, preload=True)
            data = raw.get_data()  # [channels, time]
            self.samples.extend(
                self._slice_windows(data, window_size, cut_step)
            )

    @staticmethod
    def _slice_windows(data, window_size, cut_step):
        """Cut a [channels, time] array into float32 windows along time."""
        n_times = data.shape[1]
        # "+ 1" keeps the window that ends exactly on the last sample; without
        # it a recording of exactly ``window_size`` samples yields no windows.
        last_start_exclusive = n_times - window_size + 1
        return [
            data[:, start:start + window_size].astype(np.float32)
            for start in range(0, last_start_exclusive, cut_step)
        ]

    @staticmethod
    def _augment(sig):
        """Return a noisy, circularly time-shifted float32 copy of ``sig``."""
        # Draw the noise as float32: np.random.randn returns float64, which
        # would silently promote the whole sample to float64 and then clash
        # with a float32 model.
        # NOTE(review): MNE documents get_data() as returning SI units (volts
        # for EEG), so a noise std of 0.01 may dwarf the signal - confirm the
        # data scale or normalize the windows before relying on this.
        noise = np.random.randn(*sig.shape).astype(np.float32)
        noisy = (sig + 0.01 * noise).astype(np.float32, copy=False)
        # randint's upper bound is exclusive, so the shift is in [-10, 9].
        shift = np.random.randint(-10, 10)
        return np.roll(noisy, shift, -1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return two independently augmented views (SimCLR-style) of window ``idx``."""
        x = self.samples[idx]
        return self._augment(x), self._augment(x)

In [None]:
# ====== Contrastive Loss ======
def contrastive_loss(z1, z2, temperature=0.5):
    """Compute a SimCLR-style (NT-Xent) contrastive loss for two views.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` are treated as a positive
    pair; every other row in the combined batch is a negative for both.

    Args:
        z1: Tensor of shape ``(batch, dim)`` with embeddings of the first view.
        z2: Tensor of the same shape with embeddings of the second view.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor: mean cross-entropy of picking the positive among all
        other samples, averaged over the ``2 * batch`` rows.

    Raises:
        ValueError: If ``z1`` and ``z2`` do not have the same shape.
    """
    if z1.shape != z2.shape:
        raise ValueError(
            "z1 and z2 must have the same shape, got "
            f"{tuple(z1.shape)} and {tuple(z2.shape)}"
        )

    device = z1.device
    batch_size = z1.size(0)
    n_rows = 2 * batch_size

    # Unit-normalize so the dot products below are cosine similarities.
    embeddings = torch.cat(
        [F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0
    )
    similarity = torch.matmul(embeddings, embeddings.T)

    # Build the index masks on the embeddings' device so indexing does not
    # need a host-to-device copy on every call.
    source_ids = torch.arange(batch_size, device=device).repeat(2)
    same_source = source_ids.unsqueeze(0) == source_ids.unsqueeze(1)
    off_diagonal = ~torch.eye(n_rows, dtype=torch.bool, device=device)

    # Drop self-similarities: each row keeps its 2 * batch - 1 other entries.
    same_source = same_source[off_diagonal].view(n_rows, -1)
    similarity = similarity[off_diagonal].view(n_rows, -1)

    # After removing the diagonal, exactly one entry per row is the positive.
    positives = similarity[same_source].view(n_rows, -1)
    negatives = similarity[~same_source].view(n_rows, -1)

    # The positive sits in column 0, so the target class is 0 for every row.
    logits = torch.cat([positives, negatives], dim=1) / temperature
    targets = torch.zeros(n_rows, dtype=torch.long, device=device)
    return F.cross_entropy(logits, targets)

# ====== Backbone (EEGNet) ======
# Use the GPU when one is present, otherwise fall back to CPU so the script
# still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backbone = EEGNet(
    n_chans=21,                 # number of EEG channels
    n_outputs=128,              # embedding size
    # n_times is the current name of the deprecated input_window_samples
    # argument in the braindecode API that accepts n_chans / n_outputs.
    # It must equal the window_size used by EEGDataset (512 by default).
    n_times=512,
    final_conv_length='auto'
)

# Intended to drop the classification head and use the features only.
# NOTE(review): this only sets an attribute on the module; confirm that the
# installed braindecode EEGNet actually reads `classify`, otherwise the
# 128-dimensional model output is what serves as the embedding.
backbone.classify = False
backbone = backbone.to(device)

optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-3)

# ====== Training ======
# Placeholder paths - replace with the real EDF recordings. The original list
# contained a literal `...`, which would be passed to the EDF reader as a path.
edf_files = ["file1.edf", "file2.edf"]
dataset = EEGDataset(edf_files)
if len(dataset) == 0:
    raise ValueError(
        "EEGDataset is empty: no recording is long enough for one window"
    )
loader = DataLoader(dataset, batch_size=64, shuffle=True)

backbone.train()
for epoch in range(10):
    loss_sum = 0.0
    for x1, x2 in loader:
        # Cast to float32 explicitly so the batch matches the model weights
        # regardless of the dtype the dataset hands out.
        x1 = x1.to(device=device, dtype=torch.float32)
        x2 = x2.to(device=device, dtype=torch.float32)
        z1 = backbone(x1)  # embeddings of the first view
        z2 = backbone(x2)  # embeddings of the second view
        loss = contrastive_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    # Report the mean loss over the epoch rather than only the last batch.
    print(f"Epoch {epoch} | Loss {loss_sum / len(loader):.4f}")

torch.save(backbone.state_dict(), "backbone_ssl.pt")
