# Импорты и конфиг

In [1]:
import sys

# Project modules whose cached copies are dropped from the import cache,
# so that the next import statement executes them again.
modules_to_reload = ["src.utils.eeg_controller"]

for module in modules_to_reload:
    # pop() with a default does nothing when the module was never imported.
    sys.modules.pop(module, None)

%load_ext autoreload
%autoreload 2

import gc
import itertools

import matplotlib as mpl
import mne
import numpy as np
import tensorly as tl
import torch

from src.utils.metrics_calculators import IMetricCalculator

# NOTE(review): a backend is chosen twice here - first explicitly ("Agg"),
# then through the %matplotlib magic. Presumably only one of the two is
# intended; confirm which backend should be active for this notebook.
mpl.use("Agg")
%matplotlib notebook

In [2]:
# Release cached GPU memory before the heavy cells run. The CUDA calls are
# guarded because torch.cuda.synchronize() raises when no usable CUDA device
# is present, which would abort the cell before the Python-side collection.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Last expression of the cell: the collector's return value is displayed.
gc.collect()

16

In [3]:
# Directory handed to the dataset loader as `path=`; the paths are relative
# to the notebook's working directory.
cache_dir = "../.cache/eeg"

# Скачивание

In [4]:
# Loaded recordings, keyed as eeg_raw_list[subject][run] -> Raw object.
eeg_raw_list = {}
subjects = list(range(1, 5))  # NOTE(review): original note "110" - presumably the range end for the full dataset
runs = list(range(3, 15))  # NOTE(review): original note "15" - presumably the range end for all task runs


def _normalize_channel_name(ch_name):
    """Return the channel label in the spelling used by the montage cells.

    Dots are removed and the label is upper-cased; afterwards "Z" is lowered
    and "FP1"/"FP2" are respelled "Fp1"/"Fp2". For example "Fcz." becomes
    "FCz" and "Fpz." becomes "FPz".
    """
    return ch_name.replace(".", "").upper().replace("Z", "z").replace("FP1", "Fp1").replace("FP2", "Fp2")


for subject in subjects:
    # One local file path per requested run, in the order of `runs`.
    raw_fnames = mne.datasets.eegbci.load_data(subject=subject, runs=runs, path=cache_dir)

    subject_runs = {}

    # Pair every file with the run number it was requested for. The previous
    # enumerate(..., start=3) silently mislabelled runs as soon as `runs`
    # did not start at 3 or was not contiguous.
    for run_idx, f in zip(runs, raw_fnames, strict=True):
        raw = mne.io.read_raw_edf(f, preload=True)

        raw.rename_channels({ch: _normalize_channel_name(ch) for ch in raw.ch_names})

        subject_runs[run_idx] = raw

    # Register the subject only once all of its runs were read successfully.
    eeg_raw_list[subject] = subject_runs

Extracting EDF parameters from /home/johndoe_19/git-projects/tensor-methods-comparison/.cache/eeg/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/johndoe_19/git-projects/tensor-methods-comparison/.cache/eeg/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/johndoe_19/git-projects/tensor-methods-comparison/.cache/eeg/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R05.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/johndoe_19/git-projects/tensor-methods-comparison/.cache/eeg/MNE-eegbci-data/file

# Монтаж

In [5]:
# Built-in "biosemi64" montage from MNE; serves as the template that the
# following cells copy and edit.
montage_biosemi64 = mne.channels.make_standard_montage("biosemi64")

In [6]:
# All edits below are applied to this copy rather than to the template.
montage_copy = montage_biosemi64.copy()

In [7]:
# Replacement positions passed to edit_montage_dig for the electrodes that are
# renamed to T9 and T10; the two points differ only in the sign of the first
# component.
# NOTE(review): presumably (x, y, z) in metres in the montage's coordinate
# frame - confirm the unit and frame.
coordsT9 = np.array([-0.08869014, -0.0, -0.04014873])
coordsT10 = np.array([0.08869014, 0.0, -0.04014873])

In [8]:
def edit_montage_dig(montage, old_name, new_name, new_coords=None, offset=3):
    """Rename one electrode of a montage in place and optionally move it.

    Args:
        montage: Object exposing ``ch_names`` (list of channel names) and
            ``dig`` (list of mappings with an ``"r"`` position entry). It is
            modified in place.
        old_name: Current channel name. If it is not present, a message is
            printed and the montage is left untouched.
        new_name: Name written into the slot of ``old_name``.
        new_coords: Optional new position. It is stored as a fresh float
            array, so later changes to the caller's object do not leak into
            the montage.
        offset: Number of ``dig`` entries that precede the first channel.
            Defaults to 3, the value that used to be hard-coded here and the
            default of ``update_montage_points_with_offset``.

    Raises:
        IndexError: If ``index + offset`` lies outside ``montage.dig`` while
            ``new_coords`` is given.
    """
    if old_name not in montage.ch_names:
        print(f"Электрод {old_name} не найден в монтаже")
        return

    idx = montage.ch_names.index(old_name)

    if new_coords is not None:
        # Copy instead of aliasing the caller's array.
        montage.dig[idx + offset]["r"] = np.array(new_coords, dtype=float)
        print(f"Координаты {old_name} изменены на {new_coords}")

    # Replace the name in the channel list.
    montage.ch_names[idx] = new_name
    print(f"Название электрода {old_name} изменено на {new_name}")


# (old name, new name, new position or None) for every electrode to adjust.
# "FPz" matches the spelling that the channel renaming in the download cell
# produces for that electrode.
montage_edits = [
    ("P9", "T9", coordsT9),
    ("P10", "T10", coordsT10),
    ("Fpz", "FPz", None),
]

for old_name, new_name, new_coords in montage_edits:
    edit_montage_dig(montage_copy, old_name, new_name, new_coords)

Координаты P9 изменены на [-0.08869014 -0.         -0.04014873]
Название электрода P9 изменено на T9
Координаты P10 изменены на [ 0.08869014  0.         -0.04014873]
Название электрода P10 изменено на T10
Название электрода Fpz изменено на FPz


In [9]:
# Channel names together with their 1-based position ("order").
# Every order from 1 to 64 occurs exactly once. FT8 used to carry order 8,
# which duplicated C5 and left position 40 unused; it is 40 now.
update_info = [
    {"name": "Fp1", "order": 22},
    {"name": "FPz", "order": 23},
    {"name": "Fp2", "order": 24},
    {"name": "AF7", "order": 25},
    {"name": "AF3", "order": 26},
    {"name": "AFz", "order": 27},
    {"name": "AF4", "order": 28},
    {"name": "AF8", "order": 29},
    {"name": "F7", "order": 30},
    {"name": "F5", "order": 31},
    {"name": "F3", "order": 32},
    {"name": "F1", "order": 33},
    {"name": "Fz", "order": 34},
    {"name": "F2", "order": 35},
    {"name": "F4", "order": 36},
    {"name": "F6", "order": 37},
    {"name": "F8", "order": 38},
    {"name": "FT7", "order": 39},
    {"name": "FC5", "order": 1},
    {"name": "FC3", "order": 2},
    {"name": "FC1", "order": 3},
    {"name": "FCz", "order": 4},
    {"name": "FC2", "order": 5},
    {"name": "FC4", "order": 6},
    {"name": "FC6", "order": 7},
    {"name": "FT8", "order": 40},
    {"name": "T9", "order": 43},
    {"name": "T7", "order": 41},
    {"name": "C5", "order": 8},
    {"name": "C3", "order": 9},
    {"name": "C1", "order": 10},
    {"name": "Cz", "order": 11},
    {"name": "C2", "order": 12},
    {"name": "C4", "order": 13},
    {"name": "C6", "order": 14},
    {"name": "T8", "order": 42},
    {"name": "T10", "order": 44},
    {"name": "TP7", "order": 45},
    {"name": "CP5", "order": 15},
    {"name": "CP3", "order": 16},
    {"name": "CP1", "order": 17},
    {"name": "CPz", "order": 18},
    {"name": "CP2", "order": 19},
    {"name": "CP4", "order": 20},
    {"name": "CP6", "order": 21},
    {"name": "TP8", "order": 46},
    {"name": "P7", "order": 47},
    {"name": "P5", "order": 48},
    {"name": "P3", "order": 49},
    {"name": "P1", "order": 50},
    {"name": "Pz", "order": 51},
    {"name": "P2", "order": 52},
    {"name": "P4", "order": 53},
    {"name": "P6", "order": 54},
    {"name": "P8", "order": 55},
    {"name": "PO7", "order": 56},
    {"name": "PO3", "order": 57},
    {"name": "POz", "order": 58},
    {"name": "PO4", "order": 59},
    {"name": "PO8", "order": 60},
    {"name": "O1", "order": 61},
    {"name": "Oz", "order": 62},
    {"name": "O2", "order": 63},
    {"name": "Iz", "order": 64},
]

In [10]:
def update_montage_points_with_offset(montage, update_info, offset=3):
    """Range-check channel orders and register channel names the montage lacks.

    For each entry the 1-based ``order`` is turned into a ``dig`` index
    (``order - 1 + offset``) and checked against ``len(montage.dig)``. A
    ``name`` that is not yet in ``montage.ch_names`` is appended to it.

    NOTE(review): despite its name, the function never writes to
    ``montage.dig`` and never reorders ``ch_names``; ``order`` is only
    validated. Confirm whether moving or reordering points was intended.

    Args:
        montage: Object exposing ``ch_names`` (list) and ``dig`` (sized).
        update_info: Iterable of mappings with ``"name"`` and ``"order"``.
        offset: Number of ``dig`` entries preceding the first channel.

    Raises:
        IndexError: If a computed index falls outside ``montage.dig``.
            Entries processed before the failing one have already been
            applied at that point.
    """
    for item in update_info:
        name = item["name"]
        order = item["order"] - 1 + offset

        if not 0 <= order < len(montage.dig):
            raise IndexError(f"Порядковый номер {order} выходит за пределы dig")

        # A name that is already present needs no action: the former else
        # branch only re-assigned the name to its own slot.
        if name not in montage.ch_names:
            montage.ch_names.append(name)

In [11]:
# Range-check every expected channel position against the montage's dig list
# and append channel names the montage does not contain yet.
update_montage_points_with_offset(montage_copy, update_info)

In [12]:
# Attach the edited montage to every loaded recording, subject by subject
# and run by run.
for raw in itertools.chain.from_iterable(runs_dict.values() for runs_dict in eeg_raw_list.values()):
    raw.set_montage(montage_copy)

# Вывод топографий без декомпозиции

In [13]:
# save_dir = cache_dir / Path("topographies_2/before_decomposition/by_each_epochs/")
# save_dir.mkdir(parents=True, exist_ok=True)
#
# for subject_idx, runs_dict in eeg_raw_list.items():
#     for run_idx, raw in runs_dict.items():
#         annotations = raw.annotations
#
#         events, event_id = mne.events_from_annotations(raw)
#
#         tmin, tmax = -0.5, 1.5
#         epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0), preload=True)
#
#         for event_name, event_code in event_id.items():
#             event_epochs = epochs[event_name]
#
#             for epoch_idx, epoch_data in enumerate(event_epochs.get_data(), start=1):
#                 data_mean = epoch_data.mean(axis=1)
#
#                 pos = np.array([ch["loc"][:2] for ch in event_epochs.info["chs"]])
#
#                 filename = save_dir / f"topography_subject-{subject_idx}_run-{run_idx}_event-{event_name}_epoch-{epoch_idx}.png"
#
#                 if filename.exists():
#                     print(f"File {filename} already exists. Skipping...")
#                     continue
#
#                 fig, ax = plt.subplots(figsize=(8, 8))
#                 mne.viz.plot_topomap(
#                     data_mean,
#                     pos,
#                     ch_type="eeg",
#                     names=event_epochs.ch_names,
#                     sensors=True,
#                     cmap="RdBu_r",
#                     contours=6,
#                     res=256,
#                     size=8,
#                     axes=ax,
#                     show=False,
#                 )
#
#                 plt.savefig(filename)
#                 plt.close("all")
#
#                 del fig, ax, data_mean, pos
#                 gc.collect()

# Преобразование в тензор

In [14]:
# Running total of epochs over all subjects and runs.
total_epochs = 0

for subject_idx, runs_dict in eeg_raw_list.items():
    for run_idx, raw in runs_dict.items():
        # Convert the recording's annotations into an events array and a
        # description -> code mapping.
        events, event_id = mne.events_from_annotations(raw)

        # Cut an epoch around every event; the baseline runs from the start
        # of the epoch up to the event onset.
        tmin, tmax = -0.5, 1.5
        epochs = mne.Epochs(raw, events, event_id, tmin=tmin, tmax=tmax, baseline=(None, 0), preload=True)

        # Number of epochs that remain for this run.
        num_epochs = len(epochs)
        total_epochs += num_epochs

        print(f"Subject {subject_idx}, Run {run_idx}: {num_epochs} epochs")

print(f"Общее количество эпох: {total_epochs}")

Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 30 events and 321 original time points ...
1 bad epochs dropped
Subject 1, Run 3: 29 epochs
Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 30 events and 321 original time points ...
1 bad epochs dropped
Subject 1, Run 4: 29 epochs
Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 30 events and 321 original time points ...
1 bad epochs dropped
Subject 1, Run 5: 29

In [15]:
all_subjects = list(eeg_raw_list.keys())
all_runs = list(next(iter(eeg_raw_list.values())).keys())
first_raw = next(iter(next(iter(eeg_raw_list.values())).values()))

n_channels = len(first_raw.info["ch_names"])
sfreq = first_raw.info["sfreq"]

# Epoch window in seconds relative to each event; shared by all runs.
tmin, tmax = -0.5, 1.5

# Probe the number of samples per epoch on the first recording. Events and
# data must come from the same recording: the previous version combined the
# events of `first_raw` with `raw`, a loop variable left over from an
# earlier cell.
events, event_id = mne.events_from_annotations(first_raw)
epochs = mne.Epochs(first_raw, events, event_id, tmin, tmax, baseline=(None, 0), preload=True)

n_times = epochs.get_data().shape[2]

# All annotation descriptions that occur anywhere, in sorted order; the
# position in this list is the index along the event axis.
event_types = sorted({event for runs in eeg_raw_list.values() for raw in runs.values() for event in raw.annotations.description})
n_events = len(event_types)

event_to_index = {event: idx for idx, event in enumerate(event_types)}

# per_run_data[subject][run] -> array (n_events, epochs_in_run, n_channels, n_times)
per_run_data = []

for subject_idx, runs_dict in eeg_raw_list.items():
    subject_data = []

    for run_idx, raw in runs_dict.items():
        events, event_id = mne.events_from_annotations(raw)
        epochs = mne.Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0), preload=True)

        # The epoch axis is sized by the run's total epoch count (all event
        # types together). Each event type fills only its leading rows; the
        # remaining rows stay zero.
        run_data = np.zeros((n_events, len(epochs), n_channels, n_times))

        for event_name in event_id:
            if event_name not in event_to_index:
                continue

            event_idx = event_to_index[event_name]
            event_epochs = epochs[event_name].get_data()

            run_data[event_idx, : event_epochs.shape[0], :, :] = event_epochs

        subject_data.append(run_data)

    per_run_data.append(subject_data)

# Runs may keep different numbers of epochs (bad epochs are dropped per run),
# and np.array() cannot stack such ragged blocks into a float tensor. Every
# run is therefore zero-padded to the largest epoch count. When all runs have
# the same size the result equals plain stacking.
n_runs = max((len(subject_data) for subject_data in per_run_data), default=0)
max_epochs = max((run_data.shape[1] for subject_data in per_run_data for run_data in subject_data), default=0)

eeg_data_tensor = np.zeros((len(per_run_data), n_runs, n_events, max_epochs, n_channels, n_times), dtype=np.float64)

for subject_pos, subject_data in enumerate(per_run_data):
    for run_pos, run_data in enumerate(subject_data):
        eeg_data_tensor[subject_pos, run_pos, :, : run_data.shape[1]] = run_data

# The per-run blocks are no longer needed once they are copied.
del per_run_data

Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 30 events and 321 original time points ...
1 bad epochs dropped
Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 30 events and 321 original time points ...
1 bad epochs dropped
Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 30 events and 321 original time points ...
1 bad epochs dropped
Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
30 ma

In [16]:
# Axis order of eeg_data_tensor: (subject, run, event, epoch, channel, time)
print("Формат данных:", eeg_data_tensor.shape)

Формат данных: (4, 12, 3, 29, 64, 321)


# Сжатие декомпозицией

In [17]:
# Release cached GPU memory before the decomposition. The CUDA calls are
# guarded because torch.cuda.synchronize() raises when no usable CUDA device
# is present, which would abort the cell before the Python-side collection.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Last expression of the cell: the collector's return value is displayed.
gc.collect()

7153

In [18]:
# Start from the full size of each of the six modes ...
rank = [eeg_data_tensor.shape[axis] for axis in range(6)]
# ... and halve only the last mode (integer division).
rank[5] //= 2

In [19]:
# Use TensorLy's PyTorch backend for everything inside this block.
with tl.backend_context("pytorch"):
    # Copy of the EEG tensor on the GPU.
    tensor = tl.tensor(eeg_data_tensor).to("cuda")

    # Tucker decomposition with the rank chosen above; randomly initialised
    # with a fixed seed. `weight` is the core tensor, `factors` holds one
    # factor matrix per mode.
    weight, factors = tl.decomposition.tucker(tensor, rank=rank, svd="truncated_svd", init="random", random_state=42)

    # Rebuild a full-size tensor from core and factors.
    reconstructed_tensor = tl.tucker_to_tensor((weight, factors))

    # Relative reconstruction error in percent: 100 * ||X_rec - X|| / ||X||.
    frobenius_error = 100.0 * (tl.norm(reconstructed_tensor - tensor) / tl.norm(tensor)).item()

    # NOTE(review): get_tensors_size is a project helper that is not visible
    # here - presumably the summed storage size of its arguments; confirm.
    original_size = IMetricCalculator.get_tensors_size(tensor)
    compressed_size = IMetricCalculator.get_tensors_size(weight, *factors)

    # Size of core plus factors as a percentage of the original tensor.
    compression_ratio = 100.0 * compressed_size / original_size


print(f"Compression ratio: {compression_ratio}")
print(f"Frobenius error: {frobenius_error}")

# reconstructed_tensor is not deleted here; the next cell prints its shape.
del tensor, weight, factors

torch.cuda.synchronize()
torch.cuda.empty_cache()

gc.collect()

Compression ratio: 49.91005428214631
Frobenius error: 12.090861896184725


966

In [20]:
# Axis order of reconstructed_tensor: (subject, run, event, epoch, channel, time)
print("Формат данных:", reconstructed_tensor.shape)

Формат данных: torch.Size([4, 12, 3, 29, 64, 321])


# Вывод топографий после декомпозиции

In [21]:
# save_dir = cache_dir / Path("topographies_2/after_decomposition/by_each_epochs/")
# save_dir.mkdir(parents=True, exist_ok=True)

In [22]:
# reconstructed_tensor = reconstructed_tensor.cpu().numpy().astype(np.float64)
#
# n_channels = 64
# sfreq = 160
# n_times = reconstructed_tensor.shape[-1]
#
# event_types = sorted({event for runs in eeg_raw_list.values() for raw in runs.values() for event in raw.annotations.description})
# n_events = len(event_types)
#
# event_to_index = {event: idx for idx, event in enumerate(event_types)}
#
# for subject_idx, subject_data in enumerate(reconstructed_tensor):
#     for run_idx, run_data in enumerate(subject_data):
#         for event_idx, event_name in enumerate(event_types):
#             event_epochs = run_data[event_idx]
#             for epoch_idx in range(event_epochs.shape[0]):
#                 if np.any(event_epochs[epoch_idx] != 0):
#                     data_mean = event_epochs[epoch_idx].mean(axis=1)
#
#                     pos = np.array([ch["loc"][:2] for ch in first_raw.info["chs"]])
#
#                     filename = save_dir / f"topography_subject-{subject_idx + 1}_run-{run_idx + 3}_event-{event_name}_epoch-{epoch_idx + 1}.png"
#
#                     if filename.exists():
#                         print(f"File {filename} already exists. Skipping...")
#                         continue
#
#                     fig, ax = plt.subplots(figsize=(8, 8))
#                     mne.viz.plot_topomap(
#                         data_mean,
#                         pos,
#                         ch_type="eeg",
#                         names=first_raw.info["ch_names"],
#                         sensors=True,
#                         cmap="RdBu_r",
#                         contours=6,
#                         res=256,
#                         size=8,
#                         axes=ax,
#                         show=False,
#                     )
#
#                     plt.savefig(filename)
#                     plt.close("all")
#
#                     del fig, ax, data_mean, pos
#                     gc.collect()

# Поиск оптимального ранга от frobenius error

In [23]:
# Release cached GPU memory before the rank search. The CUDA calls are
# guarded because torch.cuda.synchronize() raises when no usable CUDA device
# is present, which would abort the cell before the Python-side collection.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Last expression of the cell: the collector's return value is displayed.
gc.collect()

1934

In [28]:
# Divisors applied to each mode size to derive reduced candidate ranks.
compression_coefficients = [1, 2, 4, 8, 16]

# Full rank: the size of each of the six tensor modes.
rank = [eeg_data_tensor.shape[axis] for axis in range(6)]

print(rank)

[4, 12, 3, 29, 64, 321]


In [25]:
# Candidate ranks: every per-mode combination of `mode size // coefficient`,
# without duplicates and in order of first occurrence.
all_ranks = []
seen_ranks = set()  # tuple form of the ranks already collected (O(1) lookup)

for coeffs in itertools.product(compression_coefficients, repeat=len(rank)):
    new_rank = [r // c for r, c in zip(rank, coeffs, strict=True)]

    # Integer division gives 0 as soon as a coefficient exceeds the mode
    # size. A rank with a zero component is not a usable candidate, so it is
    # dropped here instead of being attempted in the search loop.
    if any(component < 1 for component in new_rank):
        continue

    rank_key = tuple(new_rank)
    if rank_key not in seen_ranks:
        seen_ranks.add(rank_key)
        all_ranks.append(new_rank)

In [32]:
# Smallest relative error seen so far; starts at infinity so that the first
# successful decomposition always becomes the current best.
min_frobenius_error = float("inf")

# One entry per successfully decomposed rank (filled by the search loop).
tensor_compression_logs = []

In [33]:
# Try every candidate rank, log error and size, and remember the rank with
# the smallest relative reconstruction error.
with tl.backend_context("pytorch"):
    tensor = tl.tensor(eeg_data_tensor).to("cuda")

    # The input tensor is not reassigned inside the loop, so its size is
    # measured once.
    # NOTE(review): assumes get_tensors_size has no side effects - confirm.
    original_size = IMetricCalculator.get_tensors_size(tensor)

    for index_rank, current_rank in enumerate(all_ranks):
        # Pre-bind the names so the cleanup in `finally` also works when the
        # decomposition fails before they are assigned.
        weight = factors = reconstructed_tensor = None

        try:
            weight, factors = tl.decomposition.tucker(tensor, rank=current_rank, svd="truncated_svd", init="random", random_state=42)

            reconstructed_tensor = tl.tucker_to_tensor((weight, factors))

            # Relative reconstruction error in percent.
            frobenius_error = 100.0 * (tl.norm(reconstructed_tensor - tensor) / tl.norm(tensor)).item()

            # Size of core plus factors as a percentage of the original.
            compression_ratio = 100.0 * IMetricCalculator.get_tensors_size(weight, *factors) / original_size

            tensor_compression_logs.append({"rank": current_rank, "frobenius_error": frobenius_error, "compression_ratio": compression_ratio})

            print(f"{index_rank} | frobenius error: {frobenius_error:.6f} % | compression ratio: {compression_ratio:.6f} % - {current_rank} / {rank}")

            if frobenius_error < min_frobenius_error:
                min_frobenius_error = frobenius_error
                min_frob_compression_ratio = compression_ratio
                best_rank = current_rank

        except Exception as e:
            # Best effort: report the failing rank and continue with the next.
            print(f"Error for rank {current_rank}: {e}")

        finally:
            # Free GPU memory after every attempt, including failed ones.
            del weight, factors, reconstructed_tensor
            torch.cuda.empty_cache()
            gc.collect()

    del tensor
    torch.cuda.empty_cache()
    gc.collect()

0 | frobenius error: 0.000000 % | compression error: 100.126058 % - [4, 12, 3, 29, 64, 321] / [4, 12, 3, 29, 64, 321]
1 | frobenius error: 12.090862 % | compression error: 49.910054 % - [4, 12, 3, 29, 64, 160] / [4, 12, 3, 29, 64, 321]
2 | frobenius error: 19.321309 % | compression error: 24.958003 % - [4, 12, 3, 29, 64, 80] / [4, 12, 3, 29, 64, 321]
3 | frobenius error: 26.784012 % | compression error: 12.481977 % - [4, 12, 3, 29, 64, 40] / [4, 12, 3, 29, 64, 321]
4 | frobenius error: 35.540939 % | compression error: 6.243964 % - [4, 12, 3, 29, 64, 20] / [4, 12, 3, 29, 64, 321]
5 | frobenius error: 12.003470 % | compression error: 50.123670 % - [4, 12, 3, 29, 32, 321] / [4, 12, 3, 29, 64, 321]
6 | frobenius error: 16.756704 % | compression error: 24.985549 % - [4, 12, 3, 29, 32, 160] / [4, 12, 3, 29, 64, 321]
7 | frobenius error: 22.282946 % | compression error: 12.494557 % - [4, 12, 3, 29, 32, 80] / [4, 12, 3, 29, 64, 321]
8 | frobenius error: 28.845885 % | compression error: 6.24906

  return getattr(


4895 | frobenius error: 96.779083 % | compression error: 0.251622 % - [1, 1, 3, 29, 4, 321] / [4, 12, 3, 29, 64, 321]
4896 | frobenius error: 96.793857 % | compression error: 0.126075 % - [1, 1, 3, 29, 4, 160] / [4, 12, 3, 29, 64, 321]
4897 | frobenius error: 96.781470 % | compression error: 0.063691 % - [1, 1, 3, 29, 4, 80] / [4, 12, 3, 29, 64, 321]
4898 | frobenius error: 96.804113 % | compression error: 0.032500 % - [1, 1, 3, 29, 4, 40] / [4, 12, 3, 29, 64, 321]
4899 | frobenius error: 96.904046 % | compression error: 0.016904 % - [1, 1, 3, 29, 4, 20] / [4, 12, 3, 29, 64, 321]
4900 | frobenius error: 96.685887 % | compression error: 1.131130 % - [1, 1, 3, 14, 64, 321] / [4, 12, 3, 29, 64, 321]
4901 | frobenius error: 96.709720 % | compression error: 0.566450 % - [1, 1, 3, 14, 64, 160] / [4, 12, 3, 29, 64, 321]
4902 | frobenius error: 96.692149 % | compression error: 0.285863 % - [1, 1, 3, 14, 64, 80] / [4, 12, 3, 29, 64, 321]
4903 | frobenius error: 96.729251 % | compression error: 

In [35]:
# Summary of the rank search. Axis order of the tensor shape printed last:
# (subject, run, event, epoch, channel, time)
print(f"Minimal Frobenius error, %: {min_frobenius_error:.6f} %")
print(f"Compression ratio, %: {min_frob_compression_ratio:.6f} %")
print(f"Best rank: {best_rank}")
print(f"Tensor shape: {eeg_data_tensor.shape}")

Minimal Frobenius error, %: 0.000000 %
Compression ratio, %: 48.401413 %
Best rank: [4, 12, 3, 14, 64, 321]
Tensor shape: (4, 12, 3, 29, 64, 321)


In [37]:
# Order the logged attempts by ascending reconstruction error; sorted() is
# stable, so entries with equal error keep their logging order.
sorted_logs = sorted(tensor_compression_logs, key=lambda x: x["frobenius_error"])

# One line per attempt: position | error % | compression ratio % | rank
for index_log, log in enumerate(sorted_logs):
    print(f"{index_log} | {log['frobenius_error']:.6f} % | {log['compression_ratio']:.6f} % | {log['rank']}")

0 | 0.000000 % | 48.401413 % | [4, 12, 3, 14, 64, 321]
1 | 0.000000 % | 100.126058 % | [4, 12, 3, 29, 64, 321]
2 | 12.003470 % | 50.123670 % | [4, 12, 3, 29, 32, 321]
3 | 12.003470 % | 24.261094 % | [4, 12, 3, 14, 32, 321]
4 | 12.090862 % | 49.910054 % | [4, 12, 3, 29, 64, 160]
5 | 12.090862 % | 24.128045 % | [4, 12, 3, 14, 64, 160]
6 | 16.756704 % | 12.094291 % | [4, 12, 3, 14, 32, 160]
7 | 16.756704 % | 24.985549 % | [4, 12, 3, 29, 32, 160]
8 | 18.690080 % | 25.122477 % | [4, 12, 3, 29, 16, 321]
9 | 18.690080 % | 12.190935 % | [4, 12, 3, 14, 16, 321]
10 | 19.321309 % | 24.958003 % | [4, 12, 3, 29, 64, 80]
11 | 19.321309 % | 12.066745 % | [4, 12, 3, 14, 64, 80]
12 | 21.425798 % | 6.077413 % | [4, 12, 3, 14, 16, 160]
13 | 21.425798 % | 12.523296 % | [4, 12, 3, 29, 16, 160]
14 | 22.282946 % | 6.048674 % | [4, 12, 3, 14, 32, 80]
15 | 22.282946 % | 12.494557 % | [4, 12, 3, 29, 32, 80]
16 | 23.675586 % | 12.621880 % | [4, 12, 3, 29, 8, 321]
17 | 23.675586 % | 6.155856 % | [4, 12, 3, 14, 8,