# SSL pretext-task: reconstruction

**Что я хочу получить в конце**

После SSL-предобучения должна получить:

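- Предобученный энкодер (часть U-Net без декодера) — файл с его state_dict (планировалось имя encoder_state_dict.pt; в этом ноутбуке сохраняется как encoder_ssl_unet_50k.pth).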

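- Эмбеддинги для всех SSL-эпох.
  Например, планировался массив Z_ssl_small.npy формы примерно (n_epochs, 1024, L_bottleneck). В этом ноутбуке bottleneck имеет форму (512, 13) на эпоху, а после усреднения по времени сохраняется массив формы (50000, 512) в файле ssl_embeddings_512d_mean_50k.npz.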

## Загрузка и подготовка данных

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# ==== Load the small dataset ==== #

import numpy as np
import pandas as pd

# Path (on the mounted Drive) to the prepared small sample
data_path = '/content/drive/MyDrive/MIPT/Diploma_P300/Data/ssl_data_50k.npz'

data = np.load(data_path)
X_ssl_ready = data["X"].astype(np.float32)   # (50000, 14, 208)
# "mean" / "std" are stored next to X; presumably the normalisation statistics
# used when X was prepared — TODO confirm. The cells shown here never read them.
ssl_mean    = data["mean"].astype(np.float32)
ssl_std     = data["std"].astype(np.float32)

print(X_ssl_ready.shape, X_ssl_ready.dtype)


(50000, 14, 208) float32


In [3]:
# ==== Split the sample into train / val ==== #

N = X_ssl_ready.shape[0]
# Fixed seed so the shuffled split is reproducible
rng = np.random.default_rng(42)
idx = np.arange(N)
rng.shuffle(idx)

val_ratio = 0.1
n_val = int(N * val_ratio)

# The first n_val shuffled indices form the validation set, the rest the training set
idx_val = idx[:n_val]
idx_train = idx[n_val:]

X_train = X_ssl_ready[idx_train]
X_val   = X_ssl_ready[idx_val]

print("Train:", X_train.shape, "Val:", X_val.shape)


Train: (45000, 14, 208) Val: (5000, 14, 208)


In [4]:
# Set up PyTorch and the device
import torch
from torch.utils.data import Dataset, DataLoader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [5]:
# ==== Dataset для reconstruction (x → x) ==== #

class EEGReconstructionDataset(Dataset):
    """Dataset for the reconstruction pretext task: every item is ``(x, x)``.

    All epochs are held in memory as one float32 tensor of shape (N, C, L).
    """

    def __init__(self, epochs_array: np.ndarray):
        """
        Args:
            epochs_array: array-like of shape (N, C, L). It is converted to
                float32 when it has another dtype; a float32 ndarray is
                wrapped without copying.

        Raises:
            ValueError: if ``epochs_array`` is not 3-dimensional.
        """
        # np.asarray returns the very same object for a float32 ndarray,
        # so the common case still shares memory with the caller's array.
        epochs_array = np.asarray(epochs_array, dtype=np.float32)
        # Explicit exception instead of ``assert``: asserts are stripped
        # when Python runs with -O, which would silently skip the check.
        if epochs_array.ndim != 3:
            raise ValueError("Ожидаю массив формы (N, C, L)")
        # Keep the data as a torch.Tensor from the start
        self.data = torch.from_numpy(epochs_array)

    def __len__(self):
        # Number of epochs
        return self.data.shape[0]

    def __getitem__(self, idx):
        # One epoch of shape (C, L)
        x = self.data[idx]
        # For reconstruction the target is the signal itself
        return x, x


In [6]:
# Wrap the train / val arrays as reconstruction datasets and report their sizes
train_dataset = EEGReconstructionDataset(X_train)
val_dataset   = EEGReconstructionDataset(X_val)

print("Размер тренировочного датасета:", len(train_dataset))
print("Размер валидационного датасета:", len(val_dataset))



Размер тренировочного датасета: 45000
Размер валидационного датасета: 5000


In [7]:
# ==== DataLoader ==== #

# Loader settings depend on where the model will run
if device.type == "cuda":
    BATCH_SIZE = 64        # could probably be larger
    NUM_WORKERS = 0
    PIN_MEMORY = True
else:
    BATCH_SIZE = 16        # smaller on CPU
    NUM_WORKERS = 0        # to avoid bugs
    PIN_MEMORY = False

print("batch_size:", BATCH_SIZE,
      "| num_workers:", NUM_WORKERS,
      "| pin_memory:", PIN_MEMORY)


batch_size: 64 | num_workers: 0 | pin_memory: True


In [8]:
# Training loader: reshuffled on every pass over the data
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Validation loader: fixed order
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [9]:
# Sanity check: pull one training batch and inspect shapes / dtypes

batch_x, batch_y = next(iter(train_loader))
print("batch_x:", batch_x.shape, batch_x.dtype)
print("batch_y:", batch_y.shape, batch_y.dtype)


batch_x: torch.Size([64, 14, 208]) torch.float32
batch_y: torch.Size([64, 14, 208]) torch.float32


## Модель SSL: 1D U-Net

In [10]:
import torch.nn as nn

# NOTE: rebinds N, which the train/val split cell already set from the same array
N, C, L = X_ssl_ready.shape
print("SSL data:", X_ssl_ready.shape)

# hyperparameters
# NOTE(review): "in_channels" / "base_ch" are read when the model is built and
# "lr" / "weight_decay" when the optimizer is built. "seq_len" and "epochs" are
# not read by the cells shown here — the training call passes max_epochs=30
# directly. Confirm which epoch budget is intended.
CFG = {
    "in_channels": C,
    "seq_len": L,
    "base_ch": 32,        # base channel count (lighter than 64)
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "epochs": 10,         # to start with; can be increased later
}

SSL data: (50000, 14, 208)


In [11]:
class DoubleConv1D(nn.Module):
    """Two stacked Conv1d(k=3, pad=1) -> BatchNorm1d -> ReLU units.

    The first unit maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Kernel 3 with padding 1 leaves the temporal length as is.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # Same layer order as a hand-written Sequential, so the indices
        # (and therefore the state_dict keys "block.<i>.*") are unchanged.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv1d(conv_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        features = self.block(x)
        return features


class Down1D(nn.Module):
    """Encoder stage: max-pool by 2 along time, then a DoubleConv1D."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        halve = nn.MaxPool1d(kernel_size=2, stride=2)
        convs = DoubleConv1D(in_channels, out_channels)
        # Pool first, convolve second — positions 0 and 1 of the Sequential
        self.block = nn.Sequential(halve, convs)

    def forward(self, x):
        downsampled = self.block(x)
        return downsampled


class Up1D(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv1D.

    Args:
        in_channels: channel count expected by the DoubleConv1D, i.e. the
            width of the concatenated tensor. With ``bilinear=False`` it is
            also the input width of the transposed convolution, so the
            low-resolution input must have ``in_channels`` channels and the
            skip tensor ``in_channels - out_channels``.
        out_channels: channel count produced by the stage.
        bilinear: ``True`` uses parameter-free linear interpolation (x2);
            ``False`` uses a learned ``ConvTranspose1d`` (kernel 2, stride 2)
            that maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.bilinear = bilinear

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=2, stride=2)
        # The convolution block is the same in both modes
        self.conv = DoubleConv1D(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: lower (coarser) level, x2: skip connection
        x1 = self.up(x1)

        # Align temporal lengths if they differ: zero-pad the shorter tensor,
        # splitting the padding between the left and right ends.
        diff = x2.size(-1) - x1.size(-1)
        if diff > 0:
            x1 = nn.functional.pad(x1, (diff // 2, diff - diff // 2))
        elif diff < 0:
            x2 = nn.functional.pad(x2, (-diff // 2, -diff - (-diff // 2)))

        # Skip features first, upsampled features second, along the channel axis
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv1D(nn.Module):
    """Output head: a kernel-size-1 convolution that only remaps channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=1 mixes channels at each time step independently
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        projected = self.conv(x)
        return projected


class UNet1D_Light(nn.Module):
    """
    1D U-Net with 4 downsampling levels, inspired by Hong et al., 2025.

    Encoder widths are base_ch * (1, 2, 4, 8) and the bottleneck is
    base_ch * 16, i.e. 32 → 64 → 128 → 256 with a 512-channel bottleneck
    for base_ch=32.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels of the final 1x1 convolution.
        base_ch: width of the first encoder level.
        bilinear: upsampling mode passed to every Up1D stage
            (True: linear interpolation, False: transposed convolution).
    """
    def __init__(self, n_channels, n_classes, base_ch=32, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        ch1 = base_ch
        ch2 = base_ch * 2
        ch3 = base_ch * 4
        ch4 = base_ch * 8
        bottleneck_ch = base_ch * 16  # 512 when base_ch=32

        # encoder
        self.inc = DoubleConv1D(n_channels, ch1)
        self.down1 = Down1D(ch1, ch2)
        self.down2 = Down1D(ch2, ch3)
        self.down3 = Down1D(ch3, ch4)
        self.down4 = Down1D(ch4, bottleneck_ch)

        # decoder
        # Up1D's first argument is the width fed to its DoubleConv1D, and in
        # transposed-conv mode also the width of the tensor being upsampled:
        #   * bilinear: interpolation keeps the channel count, so the concat
        #     is (low-res width + skip width);
        #   * transposed conv: it maps low-res width -> output width, and the
        #     concat with the equally wide skip is again the low-res width.
        # Always passing the sum broke bilinear=False (ConvTranspose1d got
        # fewer channels than it was built for).
        low_widths = (bottleneck_ch, ch4, ch3, ch2)
        skip_widths = (ch4, ch3, ch2, ch1)
        up_in = [
            low + skip if bilinear else low
            for low, skip in zip(low_widths, skip_widths)
        ]
        self.up1 = Up1D(up_in[0], ch4, bilinear)
        self.up2 = Up1D(up_in[1], ch3, bilinear)
        self.up3 = Up1D(up_in[2], ch2, bilinear)
        self.up4 = Up1D(up_in[3], ch1, bilinear)
        self.outc = OutConv1D(ch1, n_classes)

    def encode(self, x):
        """Run the encoder and return all five feature maps (x1 … x5)."""
        x1 = self.inc(x)     # (N, ch1, L)
        x2 = self.down1(x1)  # (N, ch2, L/2)
        x3 = self.down2(x2)  # (N, ch3, L/4)
        x4 = self.down3(x3)  # (N, ch4, L/8)
        x5 = self.down4(x4)  # (N, bottleneck_ch, L/16)
        return x1, x2, x3, x4, x5

    def decode(self, x1, x2, x3, x4, x5):
        """Run the decoder from the bottleneck x5, using x4 … x1 as skips."""
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def forward(self, x):
        """Return ``(output, bottleneck)`` for the input batch ``x``."""
        x1, x2, x3, x4, x5 = self.encode(x)
        logits = self.decode(x1, x2, x3, x4, x5)
        return logits, x5  # x5 is the bottleneck


In [12]:
# Initialisation
# n_classes is set to the input channel count: the output has as many channels as the input
model = UNet1D_Light(
    n_channels=CFG["in_channels"],
    n_classes=CFG["in_channels"],
    base_ch=CFG["base_ch"],
).to(device)

# Total parameter count, in millions
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")


2.62491 M parameters


In [13]:
# Loss and optimizer
mse_loss = nn.MSELoss()
# Adam over all model parameters; lr and weight_decay come from CFG
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CFG["lr"],
    weight_decay=CFG["weight_decay"],
)


In [14]:
# Цикл обучения reconstruction

from tqdm.auto import tqdm
import copy

def train_reconstruction(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    max_epochs=30,
    patience=5,
):
    """
    Train a reconstruction model with early stopping on the validation loss.

    model:        module whose forward returns a pair; the first element is
                  compared against the target
    train_loader: yields (input, target) batches for optimisation
    val_loader:   yields (input, target) batches for validation
    optimizer:    optimizer already bound to the model parameters
    criterion:    loss callable (MSELoss)
    device:       device the model and batches are moved to
    max_epochs:   upper bound on the number of epochs
    patience:     number of consecutive epochs without val_loss improvement
                  after which training stops

    Returns (model, history): the model with the best-validation weights
    loaded (if any epoch improved) and a dict with per-epoch "train_loss"
    and "val_loss" lists.
    """
    model.to(device)

    history = {"train_loss": [], "val_loss": []}
    best_val_loss = float("inf")
    best_weights = None
    stale_epochs = 0

    for epoch in range(1, max_epochs + 1):
        # ---- TRAIN ----
        model.train()
        loss_sum, samples_seen = 0.0, 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{max_epochs} [train]", leave=False)
        for inputs, targets in progress:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad()
            recon, _ = model(inputs)
            loss = criterion(recon, targets)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the batch size so the epoch
            # average is a per-sample mean even with a smaller last batch.
            batch_loss = loss.item()
            batch_len = inputs.size(0)
            loss_sum += batch_loss * batch_len
            samples_seen += batch_len
            progress.set_postfix(loss=batch_loss)

        train_loss = loss_sum / samples_seen

        # ---- VALIDATION ----
        model.eval()
        val_sum, val_seen = 0.0, 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                recon, _ = model(inputs)
                batch_len = inputs.size(0)
                val_sum += criterion(recon, targets).item() * batch_len
                val_seen += batch_len

        val_loss = val_sum / val_seen

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(
            f"Epoch {epoch}/{max_epochs} | "
            f"train_loss={train_loss:.6f} | "
            f"val_loss={val_loss:.6f}"
        )

        # ---- EARLY STOPPING ----
        # An epoch counts as an improvement only if it beats the best
        # validation loss by more than 1e-6.
        improved = val_loss < best_val_loss - 1e-6
        if improved:
            best_val_loss = val_loss
            # Snapshot the weights; state_dict() alone would keep live references
            best_weights = copy.deepcopy(model.state_dict())
            stale_epochs = 0
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            print(f"Early stopping: val_loss не улучшается {patience} эпох подряд.")
            break

    # Roll back to the best snapshot, if there is one
    if best_weights is not None:
        model.load_state_dict(best_weights)

    return model, history


In [15]:
# Training
# NOTE: max_epochs / patience are passed literally here, not taken from CFG
model, history = train_reconstruction(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=mse_loss,
    device=device,
    max_epochs=30,
    patience=5,
)

Epoch 1/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 1/30 | train_loss=0.067649 | val_loss=0.017307


Epoch 2/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 2/30 | train_loss=0.017919 | val_loss=0.022063


Epoch 3/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 3/30 | train_loss=0.013999 | val_loss=0.012405


Epoch 4/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 4/30 | train_loss=0.012278 | val_loss=0.015673


Epoch 5/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 5/30 | train_loss=0.010777 | val_loss=0.006091


Epoch 6/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 6/30 | train_loss=0.009657 | val_loss=0.014436


Epoch 7/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 7/30 | train_loss=0.009405 | val_loss=0.005299


Epoch 8/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 8/30 | train_loss=0.009003 | val_loss=0.005982


Epoch 9/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 9/30 | train_loss=0.008763 | val_loss=0.007778


Epoch 10/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 10/30 | train_loss=0.008420 | val_loss=0.007662


Epoch 11/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 11/30 | train_loss=0.008198 | val_loss=0.020500


Epoch 12/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 12/30 | train_loss=0.008015 | val_loss=0.004738


Epoch 13/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 13/30 | train_loss=0.007651 | val_loss=0.003871


Epoch 14/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 14/30 | train_loss=0.007687 | val_loss=0.017796


Epoch 15/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 15/30 | train_loss=0.006720 | val_loss=0.010494


Epoch 16/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 16/30 | train_loss=0.006650 | val_loss=0.003995


Epoch 17/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 17/30 | train_loss=0.006633 | val_loss=0.008119


Epoch 18/30 [train]:   0%|          | 0/704 [00:00<?, ?it/s]

Epoch 18/30 | train_loss=0.006376 | val_loss=0.007990
Early stopping: val_loss не улучшается 5 эпох подряд.


In [16]:
# Save the model (state_dict of the full U-Net, encoder and decoder)

torch.save(model.state_dict(), "/content/drive/MyDrive/MIPT/Diploma_P300/SSL/unet_ssl_recon_50k.pth")


## Достаём энкодер

In [18]:
import torch
import torch.nn as nn

class UNet1DEncoder(nn.Module):
    """
    Wrapper around the encoder part of a trained U-Net.

    Input:  x of shape (B, C, L)
    Output: bottleneck features of shape (B, C_bottleneck, L_reduced)

    The stages are the same module objects as in ``unet_model`` (they are
    referenced, not copied), registered under the same names.
    """
    # Encoder stages in execution order
    _STAGES = ("inc", "down1", "down2", "down3", "down4")

    def __init__(self, unet_model: nn.Module):
        super().__init__()
        # Reuse the already trained blocks as they are
        for stage_name in self._STAGES:
            setattr(self, stage_name, getattr(unet_model, stage_name))

    def forward(self, x):
        # Same chain as the full U-Net's encode(), keeping only the last map:
        # (B, ch1, L) -> L/2 -> L/4 -> L/8 -> (B, bottleneck_ch, L/16)
        for stage_name in self._STAGES:
            x = getattr(self, stage_name)(x)
        return x


In [19]:
# Build the encoder from the trained model
encoder = UNet1DEncoder(model).to(device)
encoder.eval()  # eval mode for embedding extraction


UNet1DEncoder(
  (inc): DoubleConv1D(
    (block): Sequential(
      (0): Conv1d(14, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down1D(
    (block): Sequential(
      (0): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv1D(
        (block): Sequential(
          (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_runni

In [20]:
# Check the output shape on one training batch
x_batch, _ = next(iter(train_loader))
x_batch = x_batch.to(device)

# Forward pass only; no gradients are needed for a shape check
with torch.no_grad():
    feats = encoder(x_batch)

print("Вход:", x_batch.shape)
print("Фичи:", feats.shape)


Вход: torch.Size([64, 14, 208])
Фичи: torch.Size([64, 512, 13])


In [21]:
# Save the encoder weights
import os
SAVE_DIR = "/content/drive/MyDrive/MIPT/Diploma_P300/SSL"

encoder_path = os.path.join(SAVE_DIR, "encoder_ssl_unet_50k.pth")
torch.save(encoder.state_dict(), encoder_path)
print("Encoder weights saved to:", encoder_path)


Encoder weights saved to: /content/drive/MyDrive/MIPT/Diploma_P300/SSL/encoder_ssl_unet_50k.pth


### Извлечение эмбеддингов

In [22]:
# ==== Dataset и DataLoader ==== #

class EEGDatasetAll(torch.utils.data.Dataset):
    """Unlabelled dataset over all epochs: each item is a single tensor."""

    def __init__(self, epochs_array: np.ndarray):
        # astype always produces a new float32 array, so the tensor does not
        # share memory with the caller's array
        as_float32 = epochs_array.astype("float32")
        self.data = torch.from_numpy(as_float32)

    def __len__(self):
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        # One epoch (C, L); no target is returned, labels are not needed for SSL
        return self.data[idx]

# Dataset over every epoch (train and val together)
full_dataset = EEGDatasetAll(X_ssl_ready)

# Separate loader settings for embedding extraction
if device.type == "cuda":
    BATCH_SIZE_EMB = 256
    NUM_WORKERS_EMB = 0
    PIN_MEMORY_EMB = True
else:
    BATCH_SIZE_EMB = 64
    NUM_WORKERS_EMB = 0
    PIN_MEMORY_EMB = False

full_loader = DataLoader(
    full_dataset,
    batch_size=BATCH_SIZE_EMB,
    shuffle=False,          # keep the original order
    num_workers=NUM_WORKERS_EMB,
    pin_memory=PIN_MEMORY_EMB,
)


In [23]:
# ==== Функция для извлечения эмбеддингов ==== #

@torch.no_grad()  # inference only: no gradients are needed
def extract_embeddings(encoder, loader, device, agg="mean"):
    """
    Run ``encoder`` over every batch of ``loader`` and stack the embeddings.

    encoder: module mapping a batch (B, C, L) to features (B, C_bottleneck, L_red)
             (e.g. UNet1DEncoder); it is moved to ``device`` and left in eval mode
    loader:  iterable of input batches, in the order the rows should come out
    device:  device the encoder and the batches are moved to
    agg:     'mean' - average over the last (time) axis -> (B, C_bottleneck);
             'flatten' - flatten everything but the batch axis
             -> (B, C_bottleneck * L_red) (needed less often)

    Returns a NumPy array with one row per input item.

    Raises ValueError if ``agg`` is not 'mean' or 'flatten'. The check runs
    before the encoder is moved or any batch is processed.
    """
    # Fail fast: the option does not depend on the data, so there is no
    # reason to run a forward pass before rejecting it.
    if agg not in ("mean", "flatten"):
        raise ValueError("agg must be 'mean' or 'flatten'")

    encoder.to(device)
    encoder.eval()

    chunks = []

    for x in loader:
        x = x.to(device, non_blocking=True)    # (B, C, L)
        feats = encoder(x)                     # (B, C_bottleneck, L_red)

        if agg == "mean":
            emb = feats.mean(dim=-1)           # (B, C_bottleneck)
        else:
            emb = feats.flatten(start_dim=1)   # (B, C_bottleneck * L_red)

        # Move each batch to host memory right away
        chunks.append(emb.cpu().numpy())

    return np.concatenate(chunks, axis=0)


In [24]:
# Time-averaged bottleneck embedding for every epoch, in dataset order
embeddings = extract_embeddings(encoder, full_loader, device, agg="mean")
print("Embeddings shape:", embeddings.shape)


Embeddings shape: (50000, 512)


In [25]:
# ==== Save the embeddings ==== #

emb_path = os.path.join(SAVE_DIR, "ssl_embeddings_512d_mean_50k.npz")
# Stored compressed under the key "X"
np.savez_compressed(emb_path, X=embeddings)
print("Embeddings saved to:", emb_path)


Embeddings saved to: /content/drive/MyDrive/MIPT/Diploma_P300/SSL/ssl_embeddings_512d_mean_50k.npz
