## Signal Analysis
This script analyzes the signals proposed for BCI use in the MACI proposal.

In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
import pandas as pd
import seaborn as sns
import scipy.signal as sgn
import scipy.interpolate as sip

from Signal import Signal
from Dataset import Dataset

In [None]:
# --- Folders -----------------------------------------------------------------
# Raw recordings are read from PATH_RAW; feature vectors and figures are
# written to PATH_FVS and PATH_RESULT respectively.
PATH = './stim_23'
PATH_RAW = os.path.join(PATH, 'train_data')
PATH_FVS = os.path.join(PATH, 'fv_data')
PATH_RESULT = os.path.join(PATH, 'results')
# makedirs(exist_ok=True) creates missing parents and does not fail when the
# folder already exists, so no separate (racy) existence check is needed.
os.makedirs(PATH_FVS, exist_ok=True)
os.makedirs(PATH_RESULT, exist_ok=True)

# --- Dataset parameters (forwarded verbatim to Dataset below) ----------------
SF = 200  # passed as `sf`; presumably the sampling frequency in Hz - TODO confirm

# Passed as `order`, `bp_lo`, `bp_hi` and `notch`; presumably the filter order,
# band-pass edges and notch frequency - TODO confirm against Dataset.
ORDER = 4
BP_LO = 4
BP_HI = 50
NOTCH = 50

# Passed as `window` and `stride`.
WINDOW = 512
STRIDE = 1

FREQ = None #freq none takes it from the file name
BW = 1
HARMS = [1, 2]
APPLY_SNR = True

# Passed as `lab_rels`; maps the label codes 99, 1 and 2 to 0, 1 and 2.
LAB_RELS = {
    99: 0,
    1: 1,
    2: 2
}

# --- Feature-vector selection (None is forwarded as-is to Dataset.get_fv) ----
SOURCE = 'freq'

CHANNELS = None
SUBJECTS = None
SESSIONS = None
LABELS = None

# Frequency value of each of the 10 per-channel bins: 22-24 and 45-47 in steps
# of 0.5. pairplot_df decomposes a feature index into (channel, frequency)
# using the length of this list.
freq_bin_values = list(np.arange(22, 24.1, 0.5)) + list(np.arange(45, 47.1, 0.5))

In [None]:
def _already_logged(existing, row, energy_tol = 0.1):
    """Tell whether *row* is already present in *existing* (energy within *energy_tol*)."""
    if existing.empty:
        return False
    cols = existing.columns
    # Every comparison is evaluated on its own before being combined: `&` binds
    # tighter than `==`, so an unparenthesised chain does not express this test.
    mask = np.abs(existing[cols[4]].to_numpy(dtype=float) - row[4]) <= energy_tol
    for pos in (0, 1, 2, 3, 5, 6):
        mask &= (existing[cols[pos]] == row[pos]).to_numpy(dtype=bool)
    return bool(mask.any())


def pairplot_df(X, y, constant, df, clean = False, freq_bins = None):
    """Log the strongest feature of every stimulation sample and extract responses.

    Each row of ``X`` is a feature vector whose index decomposes as
    ``channel * len(freq_bins) + frequency``. Rows of ``y`` whose second entry
    equals 1.0 are treated as stimulation samples; a run of consecutive
    stimulation samples followed by a non-stimulation sample is one stimulation.

    Parameters:
        X: 2-D numpy array of feature vectors, one row per sample.
        y: label rows aligned with ``X``.
        constant: value written in the first column of every new row (the
            callers in this file pass a session number or a user name).
        df: DataFrame with seven columns in the order: constant, stimulation
            number, position in the stimulation, peak bin, peak energy,
            peak channel (1-based) and peak frequency.
        clean: when True, skip rows already present in the incoming ``df``
            (exact match on every column, energy within 0.1).
        freq_bins: frequency value of each per-channel bin. Defaults to the
            module-level ``freq_bin_values``.

    Returns:
        ``(df, responses)``: ``df`` extended with the new rows (index reset),
        and one smoothed response array per completed stimulation. A
        stimulation still in progress when the data ends yields no response.
    """
    if freq_bins is None:
        freq_bins = freq_bin_values
    freq_bins = list(freq_bins)
    n_freqs = len(freq_bins)
    smooth_step = 10 #the response keeps 1 sample out of 10 and interpolates the rest

    stim_moment = 0 #position in the stimulation
    stim_number = 0 #number of stimulation

    columns = df.columns
    new_rows = []
    # 0-based channel / frequency index of every sample of the current
    # stimulation. Kept locally so the prevalent values do not depend on which
    # rows ended up in the dataframe (rows may be skipped when clean=True).
    stim_channels = []
    stim_freqs = []
    responses = []
    for i, (feat, lab) in enumerate(zip(X, y)):
        if lab[1] == 1.0:
            max_peak_bin = feat.argmax()
            max_peak_value = np.round(float(feat[max_peak_bin]), 1)
            max_peak_channel = int(max_peak_bin // n_freqs) #which channel has highest peak
            max_peak_freq = int(max_peak_bin % n_freqs) #which frequency has the highest peak independent of channel
            stim_channels.append(max_peak_channel)
            stim_freqs.append(max_peak_freq)
            row = [constant, stim_number, stim_moment, max_peak_bin, max_peak_value,
                   max_peak_channel + 1, freq_bins[max_peak_freq]]
            # (stim_number, stim_moment) is unique within one call, so a
            # duplicate can only come from rows already in the incoming df.
            if not (clean and _already_logged(df, row)):
                new_rows.append(row)
            stim_moment += 1
        elif stim_moment != 0:
            ch, ch_counts = np.unique(stim_channels, return_counts=True)
            prevalent_channel = int(ch[ch_counts.argmax()])
            fr, fr_counts = np.unique(stim_freqs, return_counts=True)
            prevalent_freq = int(fr[fr_counts.argmax()])
            # Rebuild the feature index from (channel, frequency), the inverse
            # of the // and % decomposition used above.
            prevalent_bin = prevalent_channel * n_freqs + prevalent_freq
            response_raw = X[i-stim_moment:i, prevalent_bin]
            response_raw_range = np.arange(response_raw.shape[0])
            response_soft = np.interp(response_raw_range,
                                      response_raw_range[::smooth_step],
                                      response_raw[::smooth_step])
            responses.append(response_soft)
            stim_channels = []
            stim_freqs = []
            stim_moment = 0
            stim_number += 1

    if new_rows:
        # Single concatenation instead of one per row (quadratic), with a fresh
        # index so the result has no duplicated labels.
        added = pd.DataFrame(new_rows, columns=columns)
        df = added if df.empty else pd.concat((df, added), ignore_index=True)
    return df, responses

In [None]:
def detect_peaks(s):
    """Locate the maximum and the local peaks of *s* as percentages of its length.

    Returns a dict with two lists: 'max' holds the position of the global
    maximum (a single value) and 'peaks' the positions reported by
    scipy.signal.find_peaks, both expressed as 100 * index / len(s).
    """
    n_samples = s.shape[0]
    peak_indices, _ = sgn.find_peaks(s)
    max_position = 100*s.argmax()/n_samples
    peak_positions = 100*peak_indices/n_samples
    return {'max': [max_position], 'peaks': peak_positions.tolist()}

## Individual Analysis
Analysis carried out user by user. While analyzing each user, we also generate the files needed for the global analysis.

In [None]:
# Group the raw files by user: the user id is the part of the file name before
# the first space. user_list ends up with one list of file paths per user.
each_user_current = None
each_user_list = []
user_list = []
# os.listdir returns names in arbitrary order, while the grouping below relies
# on the files of one user being consecutive. Sorting by (user, name) guarantees
# that and makes the S{i} numbering reproducible between runs.
raw_files = sorted(os.listdir(PATH_RAW), key=lambda name: (name.split(' ')[0], name))
for f in raw_files:
    user = f.split(' ')[0]
    if user != each_user_current:
        each_user_current = user
        if each_user_list:
            user_list.append(each_user_list)
        each_user_list = [os.path.join(PATH_RAW, f)]
    else:
        each_user_list.append(os.path.join(PATH_RAW, f))
# Flush the last group; skipped when the folder is empty so that no empty user
# is handed to Dataset.
if each_user_list:
    user_list.append(each_user_list)

# For every user: build the Dataset, plot the per-session histograms, and save
# the full feature vectors / labels / metadata to PATH_FVS. global_users maps
# 'S{i}' to the paths of the saved files so the global analysis can reload them.
global_users = {}
for i, user in enumerate(user_list):
    db = Dataset(user, sf = SF, order = ORDER, bp_lo = BP_LO, bp_hi = BP_HI, notch = NOTCH,
                window = WINDOW, stride = STRIDE, freq = FREQ, bw = BW, 
                harms = HARMS, apply_snr = APPLY_SNR, lab_rels = LAB_RELS
                )
    #here individual analysis
    dataframe = pd.DataFrame(columns=["session", "stim_number", "stim_moment", "max_peak_bin", "max_peak_energy", "max_peak_channel", "max_peak_freq"])
    # Responses of every session of this user. Collected here so the peak
    # histogram covers the whole user instead of only the last session that
    # happened to load, and so the name is always defined.
    user_responses = []
    for sess in range(10):
        try:
            X = db.get_fv(source = SOURCE, channels = CHANNELS, subjects = SUBJECTS, 
                       sessions = [sess], labels = LABELS
                      )
            y = db.get_onehot(subjects = SUBJECTS, 
                       sessions = [sess], labels = LABELS
                           )
            dataframe, responses = pairplot_df(X, y, sess, dataframe)
            user_responses.extend(responses)
            #fig = plt.figure(figsize=(15,5)).suptitle(f'Prevalent channel for S{i} session {sess}')
            #sns.lineplot(responses)
            del X, y
        except ValueError:
            # best effort: a session that raises ValueError is skipped
            # (presumably a session this user does not have - TODO confirm)
            continue
    
    dataframe.pop('max_peak_bin') #for display this should not be necessary having channel and frequency
    if dataframe.empty:
        # min()/max() below would fail on an empty column; nothing to draw anyway
        print(f'S{i}: no stimulation samples found, histograms skipped')
    else:
        fig, axs = plt.subplots(1, len(dataframe.columns)-2)
        fig.set_figheight(5)
        fig.set_figwidth(30)
        suptitle = fig.suptitle(f'Histograms for S{i}')
        # every axis but the last shows one dataframe column, coloured by session
        for n, ax in enumerate(axs[:-1]):
            column = dataframe.columns[n+3]
            if column != 'max_peak_energy':
                sns.histplot(dataframe, x=column, hue=dataframe.columns[0], ax=ax, binwidth=0.5, binrange=(min(dataframe[column]), 1+max(dataframe[column])))
            else:
                sns.histplot(dataframe, x=column, hue=dataframe.columns[0], ax=ax)
        # the last axis shows where, inside the stimulation, the responses peak
        peaks = []
        maxs = []
        for res in user_responses:
            peak = detect_peaks(res)
            peaks.extend(peak['peaks'])
            maxs.extend(peak['max'])
        sns.histplot(peaks, ax=axs[-1], color='tab:pink', label="peaks", kde = True)
        sns.scatterplot(y=[axs[-1].get_ylim()[1]//4]*len(maxs), x=maxs, ax=axs[-1], s=100, color='tab:brown', label="maxs")
        axs[-1].set_xlabel('stim_process_position')
        axs[-1].xaxis.set_major_formatter(PercentFormatter())
        
        fig.savefig(os.path.join(PATH_RESULT, suptitle.get_text()))
#     fig = plt.figure().suptitle(f'Pairplot for S{i}')
#     sns.pairplot(dataframe, hue=dataframe.columns[0])
    del dataframe
    
    #individual fvs save
    fv = db.get_fv(source = SOURCE, channels = CHANNELS, subjects = SUBJECTS, 
                   sessions = SESSIONS, labels = LABELS
                  )
    lab = db.get_onehot(subjects = SUBJECTS, 
                   sessions = SESSIONS, labels = LABELS
                       )
    meta = db.get_metadata()
    
    global_users[f'S{i}'] = {
        'X': os.path.join(PATH_FVS, f'X_S{i}.npy'),
        'y': os.path.join(PATH_FVS, f'y_S{i}.npy'),
        'meta': os.path.join(PATH_FVS, f'meta_S{i}.txt'),
    }
    np.save(global_users[f'S{i}']['X'], fv)
    np.save(global_users[f'S{i}']['y'], lab)
    with open(global_users[f'S{i}']['meta'], 'w') as meta_file:
        meta_file.write(str(meta))
    
    del fv, lab, meta

## Global Analysis
Analysis carried out across all users together, using the feature-vector files generated during the individual analysis.

In [None]:
# Reload the feature vectors saved per user and build one dataframe for all of
# them; the first column holds the user key ('S0', 'S1', ...) and is the hue.
dataframe = pd.DataFrame(columns=["user", "stim_number", "stim_moment", "max_peak_bin", "max_peak_energy", "max_peak_channel", "max_peak_freq"])
# Responses of every user. Collected here so the peak histogram really is
# global instead of showing only the last user, and so the name is defined even
# when global_users is empty.
global_responses = []
for s, paths in global_users.items():
    X, y = np.load(paths['X']), np.load(paths['y'])
    dataframe, responses = pairplot_df(X, y, s, dataframe)
    global_responses.extend(responses)
    #fig = plt.figure(figsize=(15,5)).suptitle(f'Prevalent channel for {s}')
    #sns.lineplot(responses)

dataframe.pop('max_peak_bin') #for display this should not be necessary having channel and frequency
fig, axs = plt.subplots(1, len(dataframe.columns)-2)
fig.set_figheight(5)
fig.set_figwidth(30)
suptitle = fig.suptitle('Histograms for global')
# every axis but the last shows one dataframe column, coloured by user
for n, ax in enumerate(axs[:-1]):
    column = dataframe.columns[n+3]
    if column != 'max_peak_energy':
        sns.histplot(dataframe, x=column, hue=dataframe.columns[0], ax=ax, binwidth = 0.5)
    else:
        sns.histplot(dataframe, x=column, hue=dataframe.columns[0], ax=ax)
# the last axis shows where, inside the stimulation, the responses peak
peaks = []
maxs = []
for res in global_responses:
    peak = detect_peaks(res)
    peaks.extend(peak['peaks'])
    maxs.extend(peak['max'])
sns.histplot(peaks, ax=axs[-1], color='tab:blue', label="peaks", kde = True)
sns.scatterplot(y=[axs[-1].get_ylim()[1]//4]*len(maxs), x=maxs, ax=axs[-1], s=100, color='tab:orange', label="maxs")
axs[-1].set_xlabel('peaks_stim_position')
axs[-1].xaxis.set_major_formatter(PercentFormatter())

fig.savefig(os.path.join(PATH_RESULT, suptitle.get_text()))
#fig = plt.figure().suptitle(f'Pairplot for global')
#sns.pairplot(dataframe, hue=dataframe.columns[0])