In [57]:
import numpy as np
import mne
import scipy.io, scipy.interpolate
import plotly.express as px
import pandas as pd

# GPU acceleration for MNE (left disabled):
# mne.utils.set_config('MNE_USE_CUDA', 'true')

# Band-pass limits (Hz); the wavelet frequencies are also spread over this range
L_FREQ = 40
H_FREQ = 300
# Number of ECoG channels in the recordings
CHANNELS_NUM = 62
# Number of Morlet wavelets (log-spaced between L_FREQ and H_FREQ) convolved with the signal
WAVELET_NUM = 40
# Target sampling rate (Hz) after downsampling
DOWNSAMPLE_FS = 100
# Hyperparameter: lag (seconds) between brain activity and the corresponding finger movement
time_delay_secs = 0.2

# Sampling rate of the processed data used by later cells
current_fs = DOWNSAMPLE_FS

def reshape_column_ecog_data(multichannel_signal: np.ndarray):
    """Swap the two axes of a column-wise recording: (time, features) -> (features, time)."""
    return multichannel_signal.transpose()

def filter_ecog_data(multichannel_signal: np.ndarray, fs=1000, powerline_freq=50, l_freq=None, h_freq=None):
    """
    Band-pass filtering followed by removal of power-line harmonics.

    :param multichannel_signal: Initial multi-channel signal; MNE filters along the last (time) axis
    :param fs: Sampling rate (Hz)
    :param powerline_freq: Grid frequency (Hz); every multiple of it strictly below Nyquist is notched out
    :param l_freq: Lower band-pass edge (Hz); None -> L_FREQ
    :param h_freq: Upper band-pass edge (Hz); None -> H_FREQ
    :return: Filtered signal
    """
    if l_freq is None:
        l_freq = L_FREQ
    if h_freq is None:
        h_freq = H_FREQ

    # All harmonics strictly below Nyquist (fs / 2). Computing the count with integer
    # division of fs // 2 dropped the last valid harmonic whenever fs / 2 is not a
    # multiple of the grid frequency (e.g. 480 Hz for 60 Hz mains at fs=1000).
    harmonics = np.arange(powerline_freq, fs / 2, powerline_freq)

    print("Starting...")
    # Band-pass: keep only the frequencies between l_freq and h_freq
    signal_filtered = mne.filter.filter_data(multichannel_signal,
                                             fs, l_freq=l_freq, h_freq=h_freq)
    print("Noise frequencies removed...")
    signal_removed_powerline_noise = mne.filter.notch_filter(signal_filtered,
                                                             fs, freqs=harmonics)  # remove powerline noise
    print("Powerline noise removed...")

    return signal_removed_powerline_noise

def normalize(multichannel_signal: np.ndarray, return_values = None):
    """
    Z-score every channel, then subtract the cross-channel median at each time step
    (a common median reference).

    :param multichannel_signal: Multi-channel signal, channels on axis 0 and time on axis 1
    :param return_values: Whether to also return the standardization parameters. By default - no
    :return: The normalized signal (same shape), or (signal, (means, stds)) when return_values is truthy.
             means/stds have shape (channels, 1); a constant channel reports std 1.0 (see below).
    """
    print("Normalizing...")
    means = np.mean(multichannel_signal, axis=1, keepdims=True)
    stds = np.std(multichannel_signal, axis=1, keepdims=True)
    # A flat channel has std 0: dividing by it yields NaN, and the NaN then leaks through
    # the median below into every channel. Leave such channels centred but unscaled.
    stds = np.where(stds == 0, 1.0, stds)
    transformed_data = (multichannel_signal - means) / stds
    # Median across channels (axis 0) at every time sample, subtracted from all channels
    common_median = np.median(transformed_data, axis=0, keepdims=True)
    transformed_data = transformed_data - common_median
    print("Normalized...")
    if return_values:
        return transformed_data, (means, stds)
    return transformed_data

def compute_spectrogramms(multichannel_signal : np.ndarray, fs=1000, freqs=np.logspace(np.log10(L_FREQ), np.log10(H_FREQ), WAVELET_NUM),
                          output_type='power', n_jobs=6, verbose=10):
    """
    Compute spectrogramms using Morlet wavelet transforms.

    :param multichannel_signal: Signal with channels on axis 0 and time on the last axis
    :param fs: Sampling rate (Hz)
    :param freqs: Wavelet frequencies to use; by default WAVELET_NUM log-spaced values in [L_FREQ, H_FREQ]
    :param output_type: Forwarded to tfr_array_morlet's `output` argument (default 'power')
    :param n_jobs: Number of parallel jobs used by MNE
    :param verbose: Verbosity forwarded to MNE
    :return: Signal spectrogramms in shape (channels, wavelets, time)
    """
    num_of_channels = len(multichannel_signal)

    print("Computing wavelets...")
    # tfr_array_morlet expects (n_epochs, n_channels, n_times): treat the whole recording
    # as a single epoch, then take that epoch back out with [0].
    single_epoch = multichannel_signal.reshape(1, num_of_channels, -1)
    spectrogramms = mne.time_frequency.tfr_array_morlet(single_epoch, sfreq=fs, freqs=freqs,
                                                        output=output_type, verbose=verbose,
                                                        n_jobs=n_jobs)[0]

    print("Wavelet spectrogramm computed...")

    return spectrogramms


def downsample_spectrogramms(spectrogramms: np.ndarray, cur_fs=1000, needed_hz=H_FREQ, new_fs = None):
    """
    Reduce the sampling rate of spectrograms by keeping every k-th time sample.

    This is plain subsampling along the last (time) axis; no anti-aliasing filter is applied.

    :param spectrogramms: Original set of spectrograms, time on the last axis (e.g. (channels, wavelets, time))
    :param cur_fs: Current sampling rate
    :param needed_hz: The maximum frequency that must be preserved; used only when new_fs is None
                      (new_fs = 2 * needed_hz)
    :param new_fs: The required sampling rate (takes precedence over needed_hz)
    :return: Subsampled copy; its sampling rate is cur_fs / (cur_fs // new_fs)
    :raises ValueError: if new_fs exceeds cur_fs (upsampling is not supported)
    """
    import warnings

    print("Downsampling spectrogramm...")
    if new_fs is None:
        new_fs = needed_hz * 2
    downsampling_coef = int(cur_fs // new_fs)
    if downsampling_coef < 1:
        raise ValueError(f"new_fs={new_fs} exceeds cur_fs={cur_fs}; upsampling is not supported")
    if cur_fs % new_fs:
        warnings.warn(f"cur_fs={cur_fs} is not a multiple of new_fs={new_fs}; "
                      f"the resulting rate is {cur_fs / downsampling_coef:g} Hz")
    # A strided slice is a view that keeps the full-rate array alive; a contiguous copy
    # lets that (potentially very large) array be freed. No copy is made when coef == 1.
    downsampled_spectrogramm = np.ascontiguousarray(spectrogramms[..., ::downsampling_coef])
    print("Spectrogramm downsampled...")
    return downsampled_spectrogramm


def normalize_spectrogramms_to_db(spectrogramms: np.ndarray, convert = False):
    """
    Optional log-scaling of the spectrograms (not used in the final version).

    Despite the name, this applies log10 (not 10 * log10); a 1e-12 offset guards against log(0).
    With convert falsy the input is returned unchanged (same object).
    """
    if not convert:
        return spectrogramms
    return np.log10(spectrogramms + 1e-12)


def interpolate_fingerflex(finger_flex, cur_fs=1000, true_fs=25, needed_hz=DOWNSAMPLE_FS, interp_type='cubic'):
    """
    Interpolation of the finger motion recording to match the new sampling rate.

    The recording (stored at cur_fs) is first reduced to its actual rate true_fs by taking
    every (cur_fs // true_fs)-th sample, then interpolated up to needed_hz. Both ratios use
    integer division, so cur_fs and needed_hz are expected to be multiples of true_fs.

    :param finger_flex: Initial sequences with finger flexions data, shape (fingers, time)
    :param cur_fs: ECoG sampling rate
    :param true_fs: Actual finger motions recording sampling rate
    :param needed_hz: Required sampling rate
    :param interp_type: Type of interpolation (scipy interp1d `kind`). By default - cubic
    :return: Returns an interpolated set of finger motions with the desired sampling rate,
             shape (fingers, n_true * (needed_hz // true_fs)) where n_true is the number of
             samples taken at true_fs
    :raises ValueError: if true_fs exceeds cur_fs or needed_hz
    """
    print("Interpolating fingerflex...")
    downscaling_ratio = cur_fs // true_fs
    upscaling_ratio = needed_hz // true_fs
    if downscaling_ratio < 1 or upscaling_ratio < 1:
        raise ValueError(f"true_fs={true_fs} must not exceed cur_fs={cur_fs} or needed_hz={needed_hz}")

    print("Computing true_fs values...")
    finger_flex_true_fs = finger_flex[:, ::downscaling_ratio]
    # Repeat the last recorded value as an extra right-hand knot; otherwise it is not clear
    # how to interpolate the tail after the final sample.
    finger_flex_true_fs = np.c_[finger_flex_true_fs, finger_flex_true_fs[:, -1]]

    # Knot positions expressed in output-rate sample indices
    ts = np.arange(finger_flex_true_fs.shape[1]) * upscaling_ratio

    print("Making funcs...")
    # One interpolator over axis 1 covers all fingers at once (the spline is fitted per row),
    # replacing a per-sample Python loop over scalar evaluations.
    interpolator = scipy.interpolate.interp1d(ts, finger_flex_true_fs, kind=interp_type, axis=1)
    # Stop before the artificial edge knot added above
    ts_needed_hz = np.arange(ts[-1])

    print("Interpolating with needed frequency")
    interpolated_finger_flex = interpolator(ts_needed_hz)
    return interpolated_finger_flex


def crop_for_time_delay(finger_flex : np.ndarray, spectrogramms : np.ndarray, time_delay_sec : float, fs : int):
    """
    Taking into account the delay between brain waves and movements.

    The first `delay` samples of the flexions and the last `delay` samples of the
    spectrograms are dropped, so that spectrogram sample i is paired with flexion sample
    i + delay. Time is taken to be the last axis of both arrays.

    :param finger_flex: Finger flexions
    :param spectrogramms: Computed spectrogramms
    :param time_delay_sec: time delay hyperparameter (seconds)
    :param fs: Sampling rate
    :return: Shifted series with a delay: (finger_flex_cropped, spectrogramms_cropped)
    :raises ValueError: if the delay in samples is negative or longer than the spectrograms
    """
    # round() rather than int(): e.g. 0.29 * 100 == 28.999999999999996 would truncate to 28
    time_delay = int(round(time_delay_sec * fs))
    n_times = spectrogramms.shape[-1]
    if not 0 <= time_delay <= n_times:
        raise ValueError(f"time delay of {time_delay} samples is outside [0, {n_times}]")

    # the first motions do not depend on available data
    finger_flex_cropped = finger_flex[..., time_delay:]
    # The latter spectrograms have no corresponding data
    spectrogramms_cropped = spectrogramms[..., :n_times - time_delay]
    return finger_flex_cropped, spectrogramms_cropped


def visualize_signal(multichannel_signal: np.ndarray, channel_num: int, second_num: int, fs=DOWNSAMPLE_FS):
    """
    Function to visualize a one-second section of one channel of a multi-channel signal.

    :param multichannel_signal: Multi-channel signal, indexed as [channel][time]
    :param channel_num: Channel selected for visualization
    :param second_num: Selected record second
    :param fs: Sampling rate
    :return: -
    """
    start = second_num * fs
    segment = np.asarray(multichannel_signal[channel_num][start:start + fs])
    # The last second of a recording can hold fewer than fs samples, so the time axis
    # is sized from the segment itself (sample index within the selected second).
    df_channel = pd.DataFrame({"t": np.arange(len(segment)), "V": segment})

    fig = px.line(df_channel, x="t", y="V", title=f'channel_{channel_num}')
    fig.show()


In [40]:
"""
Loading the raw training data and applying the processing algorithm
"""

import pathlib

PATH = f"{pathlib.Path().resolve()}/data/pure_data/"

data = scipy.io.loadmat(f'{PATH}/sub1_comp.mat')

# Finger flexions: (time, fingers) -> (fingers, time), then resampled to DOWNSAMPLE_FS
interpolated_finger_flex = interpolate_fingerflex(
    finger_flex=reshape_column_ecog_data(data['train_dg'].astype('float64'))
)

# ECoG: transpose -> normalize -> band-pass + notch -> wavelet power -> downsample ->
# dB step (a no-op, since convert defaults to False). A single name is rebound per
# stage and deleted at the end, so no extra globals are left behind.
ecog_stage = reshape_column_ecog_data(data['train_data'].astype('float64'))
ecog_stage = normalize(multichannel_signal=ecog_stage)
ecog_stage = filter_ecog_data(multichannel_signal=ecog_stage)
ecog_stage = compute_spectrogramms(multichannel_signal=ecog_stage)
ecog_stage = downsample_spectrogramms(spectrogramms=ecog_stage, new_fs=DOWNSAMPLE_FS)
db_spectrogramms = normalize_spectrogramms_to_db(spectrogramms=ecog_stage)
del ecog_stage


Interpolating fingerflex...
Computing true_fs values...
Making funcs...
Interpolating with needed frequency
Normalizing...
Normalized...
Starting...
Setting up band-pass filter from 40 - 3e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 40.00
- Lower transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 35.00 Hz)
- Upper passband edge: 300.00 Hz
- Upper transition bandwidth: 75.00 Hz (-6 dB cutoff frequency: 337.50 Hz)
- Filter length: 331 samples (0.331 sec)

Noise frequencies removed...
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 H

[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done  12 tasks      | elapsed:   10.0s
[Parallel(n_jobs=6)]: Done  62 out of  62 | elapsed:   36.3s finished


Wavelet spectrogramm computed...
Downsampling spectrogramm...
Spectrogramm downsampled...


In [42]:
"""
Loading the raw validation data and applying the processing algorithm
"""

# Only the labels (test_dg) come from sub1_testlabels.mat; the test ECoG itself
# (test_data) is read from `data`, i.e. sub1_comp.mat loaded in the training cell.
data_2 = scipy.io.loadmat(f'{PATH}/sub1_testlabels.mat')

interpolated_finger_flex_val = interpolate_fingerflex(
    finger_flex=reshape_column_ecog_data(data_2['test_dg'].astype('float64'))
)

# Same ECoG stages as for training. Note that normalize() uses this recording's own
# mean/std, not the training statistics.
ecog_val_stage = reshape_column_ecog_data(data['test_data'].astype('float64'))
ecog_val_stage = normalize(multichannel_signal=ecog_val_stage)
ecog_val_stage = filter_ecog_data(multichannel_signal=ecog_val_stage)
ecog_val_stage = compute_spectrogramms(multichannel_signal=ecog_val_stage)
ecog_val_stage = downsample_spectrogramms(spectrogramms=ecog_val_stage, new_fs=DOWNSAMPLE_FS)
db_spectrogramms_val = normalize_spectrogramms_to_db(spectrogramms=ecog_val_stage)
del ecog_val_stage


Interpolating fingerflex...
Computing true_fs values...
Making funcs...
Interpolating with needed frequency
Normalizing...
Normalized...
Starting...
Setting up band-pass filter from 40 - 3e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 40.00
- Lower transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 35.00 Hz)
- Upper passband edge: 300.00 Hz
- Upper transition bandwidth: 75.00 Hz (-6 dB cutoff frequency: 337.50 Hz)
- Filter length: 331 samples (0.331 sec)

Noise frequencies removed...
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 H

[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done  12 tasks      | elapsed:    4.1s
[Parallel(n_jobs=6)]: Done  62 out of  62 | elapsed:   17.4s finished


Wavelet spectrogramm computed...
Downsampling spectrogramm...
Spectrogramm downsampled...


In [43]:
"""
Taking time delay into account
"""


interpolated_finger_flex_cropped, db_spectrogramms_cropped = crop_for_time_delay(
    finger_flex=interpolated_finger_flex,
    spectrogramms=db_spectrogramms,
    time_delay_sec=time_delay_secs,
    fs=current_fs,
)
interpolated_finger_flex_val_cropped, db_spectrogramms_val_cropped = crop_for_time_delay(
    finger_flex=interpolated_finger_flex_val,
    spectrogramms=db_spectrogramms_val,
    time_delay_sec=time_delay_secs,
    fs=current_fs,
)

# One shape per line: train flexions, train spectrograms, val flexions, val spectrograms
print(*(arr.shape for arr in (interpolated_finger_flex_cropped, db_spectrogramms_cropped,
                              interpolated_finger_flex_val_cropped, db_spectrogramms_val_cropped)),
      sep="\n")

(5, 39980)
(62, 40, 39980)
(5, 19980)
(62, 40, 19980)


In [44]:
"""
Saving processed data
"""

import os
import pathlib

# Output root for processed arrays. The default is a machine-specific absolute path;
# set the ECOG_SAVE_PATH environment variable to use another location.
SAVE_PATH = os.environ.get("ECOG_SAVE_PATH", "/home/lomtev/brain_train/data/")


def save_proccessed_data(ecog_data, fingerflex_data, path, val=None, add_name = "", reshape = False):
    """
    Save a pair of processed arrays as .npy files under <path>/{train,val,test}/.

    :param ecog_data: ECoG features, written to ecog_data{add_name}.npy
    :param fingerflex_data: Finger flexions, written to fingerflex_data{add_name}.npy
    :param path: Root directory; its train/, val/ and test/ sub-directories are created if missing
    :param val: None -> train/, True -> val/, anything else (e.g. False) -> test/
    :param add_name: Suffix appended to both file names
    :param reshape: If True, ecog_data is reshaped to (CHANNELS_NUM * WAVELET_NUM, -1) before saving
    """
    for split_dir in ("train", "val", "test"):
        pathlib.Path(f"{path}/{split_dir}").mkdir(parents=True, exist_ok=True)

    if val is None:
        split = "train"
    elif val is True:
        split = "val"
    else:
        split = "test"
    ecog_path = f"{path}/{split}/ecog_data{add_name}.npy"
    fingerflex_path = f"{path}/{split}/fingerflex_data{add_name}.npy"

    if reshape:
        ecog_data = ecog_data.reshape(CHANNELS_NUM*WAVELET_NUM, -1)

    # np.save overwrites an existing file, so nothing needs deleting first. Deleting
    # up front raised FileNotFoundError on a fresh run, when no file existed yet.
    np.save(ecog_path, ecog_data)
    np.save(fingerflex_path, fingerflex_data)
    

In [45]:
# Unscaled, delay-aligned arrays: training split, then validation split
save_proccessed_data(ecog_data=db_spectrogramms_cropped,
                     fingerflex_data=interpolated_finger_flex_cropped,
                     path=SAVE_PATH, val=None, add_name="")
save_proccessed_data(ecog_data=db_spectrogramms_val_cropped,
                     fingerflex_data=interpolated_finger_flex_val_cropped,
                     path=SAVE_PATH, val=True, add_name="")

In [46]:
"""
Loading processed data
"""

def load_data(ecog_data_path, fingerflex_data_path):
    """Load a pair of .npy arrays and return them as (ecog_data, fingerflex_data)."""
    return tuple(np.load(npy_path) for npy_path in (ecog_data_path, fingerflex_data_path))

ecog_data, fingerflex_data = load_data(
    ecog_data_path=f"{SAVE_PATH}/train/ecog_data.npy",
    fingerflex_data_path=f"{SAVE_PATH}/train/fingerflex_data.npy",
)

ecog_data_val, fingerflex_data_val = load_data(
    ecog_data_path=f"{SAVE_PATH}/val/ecog_data.npy",
    fingerflex_data_path=f"{SAVE_PATH}/val/fingerflex_data.npy",
)

In [47]:
np.shape(fingerflex_data)

(5, 39980)

In [48]:
"""
Finger motions scaling
"""

from sklearn.preprocessing import MinMaxScaler

# MinMaxScaler scales columns, so flexions go in as (time, fingers) and come back transposed.
# The scaler is fitted on the training flexions only; validation reuses the training min/max.
scaler = MinMaxScaler().fit(fingerflex_data.T)

fingerflex_data_scaled, fingerflex_data_val_scaled = (
    scaler.transform(flex.T).T for flex in (fingerflex_data, fingerflex_data_val)
)

print(fingerflex_data_scaled.shape, fingerflex_data_val_scaled.shape)

(5, 39980) (5, 19980)


In [49]:
# Scaled flexions alongside the still-unscaled ECoG (same file names as before, so this overwrites them)
save_proccessed_data(ecog_data=ecog_data, fingerflex_data=fingerflex_data_scaled,
                     path=SAVE_PATH, val=None, add_name="")
save_proccessed_data(ecog_data=ecog_data_val, fingerflex_data=fingerflex_data_val_scaled,
                     path=SAVE_PATH, val=True, add_name="")

In [50]:
np.shape(ecog_data)

(62, 40, 39980)

In [51]:
"""
ECoG data scaling
"""

from sklearn.preprocessing import RobustScaler

# Every (wavelet, channel) pair is one scaler feature:
# (channels, wavelets, time) --.T--> (time, wavelets, channels) --reshape--> (time, wavelets*channels).
# The inverse reshape followed by .T restores (channels, wavelets, time).
# NOTE(review): sklearn's quantile_range is expressed in percent, so (0.1, 0.9) selects the
# 0.1th-0.9th percentiles, not 10%-90%. If 10-90% was intended this should be (10.0, 90.0).
transformer = RobustScaler(unit_variance=True, quantile_range=(0.1, 0.9))
transformer.fit(ecog_data.T.reshape(-1, WAVELET_NUM * CHANNELS_NUM))

ecog_data_scaled, ecog_data_val_scaled = (
    transformer.transform(ecog.T.reshape(-1, WAVELET_NUM * CHANNELS_NUM))
    .reshape(-1, WAVELET_NUM, CHANNELS_NUM)
    .T
    for ecog in (ecog_data, ecog_data_val)
)

print(ecog_data_scaled.shape, ecog_data_val_scaled.shape)

(62, 40, 39980) (62, 40, 19980)


In [52]:
# Final artefacts: scaled ECoG features and scaled flexions for both splits
save_proccessed_data(ecog_data=ecog_data_scaled, fingerflex_data=fingerflex_data_scaled,
                     path=SAVE_PATH, val=None, add_name="")
save_proccessed_data(ecog_data=ecog_data_val_scaled, fingerflex_data=fingerflex_data_val_scaled,
                     path=SAVE_PATH, val=True, add_name="")