In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import regex as re
import seaborn as sns
from scipy.signal import find_peaks

from src.data_loader import DataLoader
from src.model.measurement import KEEP_COLUMNS, Measurement

In [2]:
# Apply seaborn's default theme globally so every matplotlib figure produced
# later in the notebook shares the same look.
sns.set_theme()


In [3]:
# Number of FFT points computed per channel. It is embedded in every output
# file name so that runs with different sizes do not overwrite each other.
FOURIER_SIZE = 1000

# Every artefact of this notebook is written below this directory. Create it up
# front: np.save does not create missing parent directories, and failing with
# FileNotFoundError only after the expensive FFT loops have run is wasteful.
OUT_DIR = Path("out")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# FFT magnitudes. "flat" files hold one row per measurement with the age
# appended as the last column; "full" files keep the 3-D magnitude array.
TRAIN_OUT_FLAT_PATH = OUT_DIR / f"fft_train_{FOURIER_SIZE}.npy"
TRAIN_OUT_FULL_PATH = OUT_DIR / f"fft_train_full_{FOURIER_SIZE}.npy"
EVAL_OUT_FLAT_PATH = OUT_DIR / f"fft_eval_{FOURIER_SIZE}.npy"
EVAL_OUT_FULL_PATH = OUT_DIR / f"fft_eval_full_{FOURIER_SIZE}.npy"

# Ages of the measurements, stored in the same order as the FFT arrays.
TRAIN_AGES_PATH = OUT_DIR / f"fft_train_ages_{FOURIER_SIZE}.npy"
EVAL_AGES_PATH = OUT_DIR / f"fft_eval_ages_{FOURIER_SIZE}.npy"


In [4]:
def get_raw(measurement: Measurement) -> tuple[pd.DataFrame, mne.io.RawArray]:
    """Build a filtered MNE ``RawArray`` from one measurement's EEG channels.

    The columns of ``measurement.data`` listed in ``KEEP_COLUMNS`` are selected.
    Columns named like ``"EEG <name>-REF"`` are renamed to the capitalised
    electrode name (e.g. ``"EEG FP1-REF"`` -> ``"Fp1"``), and channels whose
    resulting name is in the removal list are dropped. The remaining channels
    are wrapped in a ``RawArray``, given the ``standard_1020`` montage,
    band-pass filtered between 1 and 90 Hz and notch filtered at 50 and 60 Hz.

    Args:
        measurement: Object whose ``data`` attribute is a DataFrame containing
            (at least) the ``KEEP_COLUMNS`` columns.

    Returns:
        A ``(mapped_data, raw)`` tuple: the renamed channel DataFrame and the
        filtered ``RawArray`` built from its values.
        NOTE(review): ``RawArray`` may share memory with ``mapped_data``; confirm
        whether the in-place filters can modify the returned frame.

    Raises:
        IndexError: If a column contains ``"EEG"`` but does not match the
            ``"EEG <name>-REF"`` pattern.
    """
    data = measurement.data

    # Original column name -> electrode name in the capitalisation used by the
    # montage ("EEG FP1-REF" -> "Fp1"). capitalize() already lower-cases the
    # tail of the string, so no separate lower() is needed.
    column_map = {
        column: re.findall(r"(?<=EEG\s)(.+)(?=\-REF)", column)[0].capitalize()
        for column in data.columns
        if "EEG" in column
    }

    to_remove = ["Ekg1", "T1", "T2", "IBI", "BURSTS", "SUPPR"]

    kept = data[KEEP_COLUMNS]

    # Collect source columns and their new names in lockstep. Zipping the
    # filtered name list against *all* kept columns would shift every name
    # after a dropped channel onto the wrong column.
    source_columns = []
    columns_formatted = []
    for column in kept.columns:
        mapped = column_map.get(column, column)
        if mapped in to_remove:
            continue

        source_columns.append(column)
        columns_formatted.append(mapped)

    mapped_data = kept[source_columns].rename(
        columns=dict(zip(source_columns, columns_formatted))
    )

    info = mne.create_info(
        ch_names=list(columns_formatted),
        sfreq=250,  # assumes the recordings are sampled at 250 Hz - TODO confirm
        ch_types=["eeg"] * len(columns_formatted),
        verbose=False,
    )

    # RawArray expects (n_channels, n_samples), hence the transpose.
    raw = mne.io.RawArray(mapped_data.values.T, info, verbose=False)
    standard_montage = mne.channels.make_standard_montage("standard_1020")

    # Band-pass 1-90 Hz.
    raw = raw.set_montage(standard_montage, verbose=False).filter(1, 90, verbose=False)

    # Notch out both 50 Hz (European) and 60 Hz (American) frequencies.
    # Available methods: 'iir', 'fir', 'fft', 'spectrum_fit'.
    raw = raw.notch_filter(50, method="fft", verbose=False)
    raw = raw.notch_filter(60, method="fft", verbose=False)

    return mapped_data, raw


def plot_raw(raw: mne.io.RawArray, include_ica: bool = False):
    """Plot the power spectral density of ``raw`` and, optionally, ICA properties.

    Two PSD figures are produced: one with a trace per channel and one averaged
    across channels.

    Args:
        raw: The recording to visualise.
        include_ica: When true, additionally fit a 14-component ICA (fixed
            random state) on a 1 Hz high-passed copy of ``raw`` and plot the
            component properties against ``raw``.
    """
    # Compute the spectrum once and reuse it for both views instead of
    # estimating the PSD twice.
    spectrum = raw.compute_psd()
    spectrum.plot()
    spectrum.plot(average=True)
    plt.show()

    if include_ica:
        ica = mne.preprocessing.ICA(n_components=14, random_state=789)
        # Fit on a filtered copy so that `raw` itself is left untouched.
        ica.fit(raw.copy().filter(1, None, verbose=False), verbose=False)
        ica.plot_properties(raw)


In [5]:
# Compute the FFT of every training measurement, batch by batch.
batch_size = 100
# Ceiling, not round(): a final partial batch still counts as a batch, and
# round() would under-count whenever the fractional part is below 0.5.
no_batches = int(np.ceil(DataLoader.train_size / batch_size))

# Accumulate in plain lists and convert once at the end; np.append copies the
# whole array on every call, which is quadratic in the number of measurements.
patients = []
ages = []
for batch_no, measurements in enumerate(DataLoader.get_train_iter(batch_size=batch_size)):
    print(f"\rBatch: {batch_no + 1:2} / {no_batches}")

    for ind, measurement in enumerate(measurements):
        print(f"\r{ind + 1:4} / {len(measurements)}", end="")
        _, raw = get_raw(measurement)
        # get_data() is (n_channels, n_samples); transpose so the FFT runs along
        # time (axis 0). With n=FOURIER_SIZE the signal is cropped or zero-padded
        # to exactly FOURIER_SIZE samples before transforming.
        fourier = np.fft.fft(raw.get_data().T, n=FOURIER_SIZE, axis=0)

        patients.append(fourier)
        ages.append(measurement.age)

patients = np.array(patients)
# Kept as a float array here; the integer cast happens in a later cell.
ages = np.array(ages, dtype=float)


Batch:  1 / 12
Batch:  2 / 12
Batch:  3 / 12
Batch:  4 / 12
Batch:  5 / 12
Batch:  6 / 12
Batch:  7 / 12
Batch:  8 / 12
Batch:  9 / 12
Batch: 10 / 12
Batch: 11 / 12
Batch: 12 / 12
  71 / 71

In [6]:
# Shape check: (number of measurements, FOURIER_SIZE, number of channels).
patients.shape


(1171, 1000, 21)

In [7]:
# Change dtype of ages to int. The ages were accumulated as floats; note that
# astype(int) truncates toward zero rather than rounding.
ages = ages.astype(int)


In [8]:
# Sanity check: there should be exactly one age per entry in `patients`.
ages.shape


(1171,)

In [9]:
# Write the training ages to TRAIN_AGES_PATH as a .npy file.
np.save(TRAIN_AGES_PATH, ages)

In [10]:
# np.abs turns the complex FFT coefficients into magnitudes; the shape is
# unchanged.
np.abs(patients).shape

(1171, 1000, 21)

In [11]:
# Save the FFT magnitudes in their 3-D layout (measurement, frequency bin,
# channel). The phase information is discarded by np.abs.
np.save(TRAIN_OUT_FULL_PATH, np.abs(patients))


In [12]:
# Flatten each measurement's (frequency bin, channel) magnitude matrix into a
# single feature row. reshape uses C order, so within a row the channel index
# varies fastest.
patients_abs = np.abs(patients).reshape(patients.shape[0], -1)

In [13]:
# Shape check: (number of measurements, FOURIER_SIZE * number of channels).
patients_abs.shape

(1171, 21000)

In [14]:
# Append each measurement's age as the last column so features and age are
# stored together in one 2-D array. column_stack does this in a single
# vectorized call and raises if the number of ages does not match the number
# of feature rows, instead of silently truncating to the shorter input.
res = np.column_stack((patients_abs, ages))
np.save(TRAIN_OUT_FLAT_PATH, res)

In [15]:
# Compute the FFT of every evaluation measurement.
# Accumulate in plain lists and convert once at the end; np.append copies the
# whole array on every call, which is quadratic in the number of measurements.
patients_eval = []
ages_eval = []
measurements = DataLoader.get_eval()
for ind, measurement in enumerate(measurements):
    print(f"\r{ind + 1:4} / {len(measurements)}", end="")
    _, raw = get_raw(measurement)
    # get_data() is (n_channels, n_samples); transpose so the FFT runs along
    # time (axis 0). With n=FOURIER_SIZE the signal is cropped or zero-padded
    # to exactly FOURIER_SIZE samples before transforming.
    fourier = np.fft.fft(raw.get_data().T, n=FOURIER_SIZE, axis=0)

    patients_eval.append(fourier)
    ages_eval.append(measurement.age)

patients_eval = np.array(patients_eval)
# Kept as a float array here; the integer cast happens in a later cell.
ages_eval = np.array(ages_eval, dtype=float)


 126 / 126

In [16]:
# Shape check: (number of measurements, FOURIER_SIZE, number of channels).
patients_eval.shape


(126, 1000, 21)

In [17]:
# Save the evaluation FFT magnitudes in their 3-D layout (measurement,
# frequency bin, channel). The phase information is discarded by np.abs.
np.save(EVAL_OUT_FULL_PATH, np.abs(patients_eval))

In [18]:
# Flatten each measurement's (frequency bin, channel) magnitude matrix into a
# single feature row. reshape uses C order, so within a row the channel index
# varies fastest.
patients_eval_abs = np.abs(patients_eval).reshape(patients_eval.shape[0], -1)


In [19]:
# Shape check: (number of measurements, FOURIER_SIZE * number of channels).
patients_eval_abs.shape

(126, 21000)

In [20]:
# The ages were accumulated as floats; cast them to int. Note that astype(int)
# truncates toward zero rather than rounding.
ages_eval = ages_eval.astype(int)


In [21]:
# Write the evaluation ages to EVAL_AGES_PATH as a .npy file.
np.save(EVAL_AGES_PATH, ages_eval)

In [22]:
# Display the evaluation ages for a quick visual check of the values.
ages_eval


array([69, 80, 78, 42, 34, 60, 28, 28, 47, 20, 40, 47, 89, 40, 26, 30, 30,
       45, 80, 81, 59, 62, 37, 67, 67, 55, 22, 48, 44, 26, 47, 42, 54, 53,
       39, 54, 39, 38, 51, 50, 31, 37, 23, 62, 52, 44, 88, 62, 24, 49, 50,
       38, 41, 39, 24, 63, 30, 47, 64, 35, 32, 19, 25, 44, 25, 65, 68, 43,
       46, 18, 19, 31, 58, 31, 33, 54, 74, 23, 62, 25, 58, 53, 38, 28, 49,
       24, 59, 37, 68, 26, 83, 32, 66, 35, 35, 47, 26, 43, 41, 70, 19, 23,
       25, 34, 65, 79, 56, 58, 31, 48, 58, 43, 61, 29, 48, 71, 62, 43, 36,
       36, 55, 52, 81, 22, 21, 66])

In [23]:
# Append each measurement's age as the last column so features and age are
# stored together in one 2-D array. column_stack does this in a single
# vectorized call and raises if the number of ages does not match the number
# of feature rows, instead of silently truncating to the shorter input.
res = np.column_stack((patients_eval_abs, ages_eval))
np.save(EVAL_OUT_FLAT_PATH, res)