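## Пакетная обработка данных ЭЭГ по пациентам

Для каждого пациента и каждого электрода строится полиномиальная модель (`pln.Tree`), предсказывающая сигнал этого электрода по сигналам остальных электродов. Деревья сохраняются отдельно для каждого пациента, затем приводятся к матрице линейных уравнений и записываются в `resources/subjects_mx.pkl`.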

In [12]:
import copy
import glob
import os
import pickle
import time

import numpy as np
import pandas as pd
import sympy
from matplotlib import pyplot as plt
from scipy.io import loadmat
from tqdm.notebook import tqdm

import polynomial as pln

Настройка базовых параметров и файлов данных

In [2]:
max_layer = 10  # maximum number of layers in a polynomial tree

retry_num = 7  # number of fitting attempts per convergence-threshold value
error_goal = 0.005  # target mean squared error; fitting of an electrode stops once it is reached
conv_thres_b = 0.02  # base (initial) convergence threshold, halved after `retry_num` attempts
conv_thres_min = 0.01  # smallest convergence threshold that is still tried

n_points = 100  # number of samples in each training set

# Seed NumPy's global RNG: the processing loop draws training samples from it,
# so without a seed every run trains on different points.
random_seed = 42
np.random.seed(random_seed)

path_to_files = "resources/subject_*.mat"
output_file = "resources/subjects_mx.pkl"
# glob() returns files in arbitrary, OS-dependent order; sort for a stable subject order.
files = sorted(glob.glob(path_to_files))

Настройка формата данных ЭЭГ

In [3]:
# Electrode labels in the order of the per-electrode SIGNAL columns (after the "time" column).
electrodes = (
    "Fp1 Fp2 Fc5 Fz Fc6 T7 Cz T8 "
    "P7 P3 Pz P4 P8 O1 Oz O2"
).split()
# Column names for the whole SIGNAL matrix: "time" first, then one column per electrode.
col_map = ["time", *electrodes]
# Every electrode is used in turn as the regression target (same list object).
train_ids = electrodes

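Цикл обработки: для каждого пациента сигналы электродов центрируются и нормируются на размах, затем для каждого электрода по `n_points` случайным отсчётам обучается полиномиальное дерево, предсказывающее его сигнал по остальным электродам. Деревья пациента сохраняются в файл `<имя исходного файла>.pkl`.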

In [13]:
def _grow_tree(training_set, training_goal, conv_thres):
    """Fit one polynomial tree, adding layers while the training error does not increase.

    Starts from a one-layer ``pln.Tree`` and tries at most ``max_layer - 1`` further
    layers (``max_layer`` comes from the config cell). A new layer is kept only if
    ``error()`` does not go up (ties are accepted), matching the original stopping rule.

    Args:
        training_set: mapping electrode name -> normalised input samples.
        training_goal: normalised target samples.
        conv_thres: convergence threshold forwarded to ``Tree.regress``.

    Returns:
        ``(tree, error)`` of the best accepted tree.
    """
    best_tree = pln.Tree(inputs=np.array(list(training_set.keys())), restrictions=[0, 1, 0, 1, 0, 1])
    best_tree.regress(training_set, training_goal, conv_thres=conv_thres)
    best_error = best_tree.error()
    for _ in range(max_layer - 1):
        # `new_layer` is called for its side effect (its result is unused), i.e. it
        # modifies the tree in place. Grow a copy so that `best_tree` never receives
        # a layer that is later rejected.
        candidate = copy.deepcopy(best_tree)
        candidate.new_layer(restrictions=[0, 1, 0, 1, 0, 1], k=5)
        candidate.regress(training_set, training_goal, conv_thres=conv_thres)
        candidate_error = candidate.error()
        if not candidate_error <= best_error:
            break  # the extra layer made the fit worse: keep the previous state
        best_tree, best_error = candidate, candidate_error
    return best_tree, best_error


def _fit_electrode(training_set, training_goal, progress):
    """Repeat ``_grow_tree`` with a decreasing convergence threshold; return the best result.

    Each threshold gets ``retry_num`` attempts, starting at ``conv_thres_b`` and
    halving until it falls below ``conv_thres_min``. The search ends early once
    the error reaches ``error_goal``. All of these values come from the config cell.

    Args:
        training_set: mapping electrode name -> normalised input samples.
        training_goal: normalised target samples.
        progress: tqdm bar whose postfix shows the attempt number and best MSE.

    Returns:
        ``(tree, error)`` of the lowest-error tree over all attempts.

    Raises:
        RuntimeError: if the config allows no attempts (``conv_thres_b < conv_thres_min``).
    """
    best_tree, best_error = None, np.inf
    conv_thres = conv_thres_b
    try_n = 0
    while conv_thres >= conv_thres_min:
        for _ in range(retry_num):
            try_n += 1
            tree, error = _grow_tree(training_set, training_goal, conv_thres)
            if best_tree is None or error < best_error:
                best_tree, best_error = tree, error
            progress.set_postfix(ordered_dict={
                "iteration": try_n,
                "mse": best_error,
            })
            if best_error <= error_goal:
                return best_tree, best_error
        conv_thres /= 2
    if best_tree is None:
        raise RuntimeError(
            f"no fitting attempts: conv_thres_b={conv_thres_b} < conv_thres_min={conv_thres_min}"
        )
    return best_tree, best_error


for file in tqdm(files, desc="subjects"):
    # SIGNAL columns follow `col_map`: a "time" column, then one column per electrode.
    signal = loadmat(file)["SIGNAL"]
    subj = pd.DataFrame(dict(zip(col_map, signal.T)))

    # Centre each channel and scale it by its range (max - min). A flat channel has
    # zero range; it is left unscaled instead of turning into inf/NaN.
    mean_map = {el: subj[el].mean() for el in electrodes}
    norm_map = {el: (subj[el].values.max() - subj[el].values.min()) or 1.0 for el in electrodes}
    control_set = {el: (subj[el].values - mean_map[el]) / norm_map[el] for el in electrodes}

    trees = []
    # Iterating over the bar advances it once per electrode; no manual update() calls.
    with tqdm(train_ids, leave=False) as t_train_ids:
        for train_id in t_train_ids:
            t_train_ids.set_description(f"{file}:{train_id}")

            # n_points *distinct* sample positions (randint could repeat them), in time order.
            ids = np.sort(np.random.choice(len(subj), size=n_points, replace=False))
            training_goal = control_set[train_id][ids]
            # Inputs are pandas Series indexed by sample position, normalised like control_set.
            training_set = {
                el: (subj[el].iloc[ids] - mean_map[el]) / norm_map[el]
                for el in electrodes
                if el != train_id
            }

            tree, _ = _fit_electrode(training_set, training_goal, t_train_ids)
            tree.prune(top_k=1)
            trees.append(tree)

    # One pickle per subject: the pruned trees in `train_ids` order.
    with open(f"{file}.pkl", "wb") as f:
        pickle.dump(trees, f)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [14]:
# The processing loop saved one pickle per subject next to its source file ("<file>.pkl").
# Build the pattern under a new name so re-running this cell does not append ".pkl" twice.
path_to_trees = path_to_files + ".pkl"
# Sorted so the subject order (first axis of `mx`) is deterministic across platforms.
files = sorted(glob.glob(path_to_trees))

trees_s = []
for file in files:
    # NOTE: pickle.load can execute arbitrary code; only load pickles produced by this notebook.
    with open(file, "rb") as f:
        trees_s.append(pickle.load(f))

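Приведём полиномиальные деревья к виду матрицы линейных уравнений: для каждого пациента и каждого целевого электрода строка содержит коэффициенты уравнения при сигналах всех электродов и свободный член (последний столбец).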

In [15]:
def _equation_row(tree):
    """Return one matrix row for a fitted tree.

    The row holds d(eq)/d(№<el>) for every electrode in ``electrodes`` order,
    followed by the free (constant) term of the tree's sympy equation.
    """
    eq = tree.node_to_equation()  # build the expression once per tree, not once per column
    coefficients = [eq.diff(sympy.Symbol("№" + el)) for el in electrodes]
    # Substituting 0 for every symbol leaves only the free term of a polynomial. It also
    # works when the expression is a single product, symbol or number. In those cases,
    # rebuilding `eq.func` from its symbol-free args returns the coefficient or fails.
    constant = eq.subs({symbol: 0 for symbol in eq.free_symbols})
    return coefficients + [constant]


def _subject_rows(trees):
    """Return the rows for one subject: one row per target electrode, in the stored order.

    Raises:
        ValueError: if the subject does not have exactly one tree per electrode
            (``zip`` would otherwise silently truncate and produce a ragged array).
    """
    if len(trees) != len(electrodes):
        raise ValueError(f"expected {len(electrodes)} trees (one per electrode), got {len(trees)}")
    return [_equation_row(tree) for tree in trees]


mx = np.array([_subject_rows(trees) for trees in trees_s])
data = {
    'mx': mx,
    # Axis labels: subject files, target electrodes, input electrodes plus the constant column (1).
    'key': [files, electrodes, electrodes + [1]],
}

In [16]:
# Persist the coefficient matrix together with its axis labels.
with open(output_file, mode="wb") as out_file:
    pickle.dump(data, out_file, protocol=pickle.DEFAULT_PROTOCOL)