In [1]:
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import regex as re

from src.data_loader import DataLoader
from src.model.measurement import KEEP_COLUMNS, Measurement

In [2]:
FOURIER_SIZE = 10000

In [3]:
def get_raw(measurement: Measurement) -> tuple[pd.DataFrame, mne.io.RawArray]:
    data = measurement.data

    column_map = dict(
        zip(
            [col for col in data.columns if "EEG" in col],
            [
                re.findall(r"(?<=EEG\s)(.+)(?=\-REF)", column)[0].lower().capitalize()
                for column in data.columns
                if "EEG" in column
            ],
        )
    )

    to_remove = ["Ekg1", "T1", "T2", "IBI", "BURSTS", "SUPPR"]

    columns_formatted = []
    for column in data[KEEP_COLUMNS].columns:
        mapped = column_map.get(column, column)
        if mapped in to_remove:
            continue

        columns_formatted.append(mapped)

    mapped_data = data[KEEP_COLUMNS].rename(columns=dict(zip(data[KEEP_COLUMNS].columns, columns_formatted)))
    mapped_data = mapped_data[columns_formatted]

    info = mne.create_info(
        ch_names=list(columns_formatted),
        sfreq=250,
        ch_types=["eeg"] * len(columns_formatted),
        verbose=False
    )

    raw = mne.io.RawArray(mapped_data.values.T, info, verbose=False)
    standard_montage = mne.channels.make_standard_montage("standard_1020")

    raw = raw.set_montage(standard_montage, verbose=False).filter(1, 90, verbose=False)

    # Filter 50 for European and 60 for American
    # 'iir', 'fir', 'fft', 'spectrum_fit'
    raw = raw.notch_filter(50, method='fft', verbose=False) # method="spectrum_fit"
    raw = raw.notch_filter(60, method='fft', verbose=False) # method="spectrum_fit"

    return mapped_data, raw


def plot_raw(raw: mne.io.RawArray, include_ica: bool = False):
    raw.compute_psd().plot()
    raw.compute_psd().plot(average=True)
    plt.show()

    if include_ica:
        ica = mne.preprocessing.ICA(n_components=14, random_state=789)
        ica.fit(raw.copy().filter(1, None, verbose=False), verbose=False)
        # ica.plot_components()
        ica.plot_properties(raw)

In [4]:

patients = []
ages = np.array([])  # 1D array
batch_size = 100
no_batches = round(DataLoader.train_size / batch_size)
for batch_no, measurements in enumerate(DataLoader.get_train_iter(batch_size=100)):
    print(f"\rBatch: {batch_no + 1:2} / {no_batches}")

    for ind, measurement in enumerate(measurements):
        print(f"\r{ind + 1:4} / {len(measurements)}", end="")
        _, raw = get_raw(measurement)
        fourier = np.fft.fft(raw.get_data().T, n=FOURIER_SIZE, axis=0)

        patients.append(fourier)
        ages = np.append(ages, measurement.age)

patients = np.array(patients)


Batch:  1 / 12
Batch:  2 / 12
Batch:  3 / 12
Batch:  4 / 12
Batch:  5 / 12
Batch:  6 / 12
Batch:  7 / 12
Batch:  8 / 12
Batch:  9 / 12
Batch: 10 / 12
Batch: 11 / 12
Batch: 12 / 12
  71 / 71

In [12]:
patients = patients.reshape(1171, 10000, 21)

AttributeError: 'tuple' object has no attribute 'reshape'

In [11]:
patients

(1171, 10000, 21)

In [6]:
# Change dtype of ages to int
ages = ages.astype(int)

In [13]:
ages.shape

(1171,)

In [None]:
# Get modulus of complex numbers from rows
rows_with_age = np.array([np.hstack((np.abs(row), age)) for row, age in zip(patients, ages)])

In [None]:
rows_eval = np.array([])
ages_eval = np.array([])
measurements = DataLoader.get_eval()
for ind, measurement in enumerate(measurements):
    print(f"\r{ind + 1:4} / {len(measurements)}", end="")
    _, raw = get_raw(measurement)
    fourier = np.fft.fft2(raw.get_data().T, s=(FOURIER_SIZE, 1))

    rows_eval = np.vstack((rows_eval, fourier)) if rows_eval.size else fourier
    ages_eval = np.append(ages_eval, measurement.age)

In [None]:
rows_eval.shape

In [None]:
rows_eval_up = rows_eval.reshape(-1, 10000)


In [None]:
ages_eval = ages_eval.astype(int)


In [None]:
ages_eval

In [None]:
rows_with_age_eval = np.hstack((np.abs(rows_eval_up), np.angle(rows_eval_up), ages_eval.reshape(-1, 1)))
pd.DataFrame(rows_with_age_eval).to_csv("out/fft_eval_full.csv", index=False)

In [None]:
pd.read_csv("out/fft_eval_full.csv")