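The goal of this notebook is to transform the raw EEG training data into per-channel features (the variance plus the DFT amplitudes and phases of each channel) and to output a CSV file with these features that can be used to train the machine learning model.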

In [1]:
import numpy as np
import pandas as pd
import random

In [2]:
# Directory prefixes of the input files. EEG_PATH is prepended to
# "<eeg_id>.parquet" when a recording is loaded; SPEC_PATH is not referenced
# in the cells shown here.
EEG_PATH = '../train_eegs/'
SPEC_PATH = '../train_spectrograms/'

# Channel names, in the order used to build the output column headers
CHANNELS = [
    'Fp1', 'F3', 'C3', 'P3',
    'F7', 'T3', 'T5', 'O1',
    'Fz', 'Cz', 'Pz',
    'Fp2', 'F4', 'C4', 'P4',
    'F8', 'T4', 'T6', 'O2',
    'EKG',
]

# Bounds of the frequency band kept from each channel's spectrum
MIN_FREQUENCY = 0
MAX_FREQUENCY = 10

# Number of amplitude (and phase) columns per channel. The factor 6 has to
# stay equal to the factor used for the spectrum slice in row_transform.
NUM_FREQUENCIES = 6 * (MAX_FREQUENCY - MIN_FREQUENCY)

# Number of samples of train.csv that get processed
NUM_ROWS = 10000

In [3]:
# Function that builds the empty output table with the feature headers
def create_empty_dataset():
    """
    Returns an empty DataFrame that only carries the output column headers
    Header order is:
        For each channel in CHANNELS: "var_<channel>", then NUM_FREQUENCIES
        "amp_<channel>_<j>" columns, then NUM_FREQUENCIES "phase_<channel>_<j>"
        columns
        A final "label" column
    The header list and its length are printed as a side effect
    """
    column_names = []
    bin_indices = range(NUM_FREQUENCIES)

    for name in CHANNELS:
        # One variance column, then the amplitude block, then the phase block
        column_names += ["var_" + name]
        column_names += ["amp_" + name + "_" + str(k) for k in bin_indices]
        column_names += ["phase_" + name + "_" + str(k) for k in bin_indices]

    column_names += ["label"]

    print(column_names)
    print(len(column_names))

    return pd.DataFrame(columns=column_names)

In [4]:
# Function to generate the output values
def row_transform(row, min_frequency, max_frequency):
    """
    Takes as input a single sample from the train set and returns a reformatted row
    New format is:
        For each channel, the (1) variance, (2) amplitudes, (3) phases
        The ground truth label

    Args:
        row: one record of train.csv; the attributes eeg_id,
            eeg_label_offset_seconds and expert_consensus are read.
        min_frequency: the kept spectrum slice starts at bin 6 * min_frequency.
        max_frequency: the kept spectrum slice stops before bin 6 * max_frequency.

    Returns:
        The feature values as a flat list ending with row.expert_consensus.
        An empty list is returned when the 10 s window contains NaN values or
        holds fewer rows than expected, so the caller can skip the sample.
    """
    # The slicing works in rows: 200 rows are taken per second
    # (presumably the EEG sampling rate -- TODO confirm against the data docs)
    rows_per_second = 200
    window_start_seconds = 20
    window_seconds = 10
    expected_rows = window_seconds * rows_per_second

    # Read in the EEG
    eeg_full = pd.read_parquet(f'{EEG_PATH}{row.eeg_id}.parquet')

    # Get the middle 10s that the diagnosis comes from; the window begins
    # 20 s after the label offset
    eeg_offset = int(row.eeg_label_offset_seconds)
    first_row = (eeg_offset + window_start_seconds) * rows_per_second
    eeg_10s = eeg_full.iloc[first_row:first_row + expected_rows]

    # A window cut short by the end of the recording would give a row of the
    # wrong width (or make rfft fail on empty input); treat it as incomplete
    if len(eeg_10s) < expected_rows:
        return []

    if np.any(np.isnan(eeg_10s)):
        return []

    # NOTE(review): for a 10 s window the rfft bin spacing is 1/10 Hz, i.e.
    # 10 bins per Hz; with 6 bins per Hz the slice may not span the full
    # min_frequency..max_frequency band in Hz -- confirm the intended factor
    # (NUM_FREQUENCIES must be changed together with it).
    first_bin = 6 * min_frequency
    last_bin = 6 * max_frequency

    new_row = []

    # Iterate through all the channels. Columns are taken by position, so the
    # parquet column order is assumed to match CHANNELS -- TODO confirm.
    for column_index in range(len(CHANNELS)):

        # Get the data for the specific channel
        signal = eeg_10s.iloc[:, column_index]

        # Perform DFT using numpy's rfft function
        dft_result = np.fft.rfft(signal)
        amplitude = np.abs(dft_result)
        phase = np.angle(dft_result)

        # Add to the row
        new_row.append(np.var(signal))
        new_row.extend(amplitude[first_bin:last_bin])
        new_row.extend(phase[first_bin:last_bin])

    # Add the ground truth
    new_row.append(row.expert_consensus)

    return new_row


In [5]:
# Read in the train dataset
train = pd.read_csv('../train.csv')

# Empty frame; it supplies the column headers of the output table
train_reformat = create_empty_dataset()

print(train_reformat)

rows_removed = 0

# Randomize if not processing full dataset
# (random.sample raises ValueError when NUM_ROWS is larger than len(train))
if NUM_ROWS != len(train):
    random_indices = random.sample(range(len(train)), NUM_ROWS)

    # Create a new DataFrame with the random subset
    train = train.iloc[random_indices]

    print("\nSelecting a random subset of data to avoid problems arising from correlation between samples in same session")
    print("New subset has dimensions: {}".format(train.shape))
    print(train.head())

# Iterate through all rows
print("Processing {} raw samples\n".format(NUM_ROWS))

# Collect the rows in a plain list and build the DataFrame once at the end:
# growing a DataFrame one row at a time re-allocates it on every iteration
headers = list(train_reformat.columns)
processed_rows = []

for r in range(NUM_ROWS):
    new_row = row_transform(train.iloc[r], MIN_FREQUENCY, MAX_FREQUENCY)

    # An empty row signals NaN issues in the sample
    if not new_row:
        rows_removed += 1
    elif len(new_row) != len(headers):
        # Same failure mode as assigning a mismatched row via .loc
        raise ValueError(
            "Row {} has {} values but {} columns are expected".format(
                r, len(new_row), len(headers)
            )
        )
    else:
        processed_rows.append(new_row)

    if r % 100 == 0:
        print("Done with " + str(r) + " iterations.")

train_reformat = pd.DataFrame(processed_rows, columns=headers)

print("\nSamples successfully processed: {}".format(len(train_reformat)))
print("Samples removed due to incomplete data: {}".format(rows_removed))

# Internal sanity check: every sample was either kept or counted as removed
assert len(train_reformat) + rows_removed == NUM_ROWS, "Row mismatch"

# Write the reformatted DataFrame to CSV
train_reformat.to_csv('train_{}_samples_{}_to_{}_hz.csv'.format(NUM_ROWS, MIN_FREQUENCY, MAX_FREQUENCY), index=False)

['var_Fp1', 'amp_Fp1_0', 'amp_Fp1_1', 'amp_Fp1_2', 'amp_Fp1_3', 'amp_Fp1_4', 'amp_Fp1_5', 'amp_Fp1_6', 'amp_Fp1_7', 'amp_Fp1_8', 'amp_Fp1_9', 'amp_Fp1_10', 'amp_Fp1_11', 'amp_Fp1_12', 'amp_Fp1_13', 'amp_Fp1_14', 'amp_Fp1_15', 'amp_Fp1_16', 'amp_Fp1_17', 'amp_Fp1_18', 'amp_Fp1_19', 'amp_Fp1_20', 'amp_Fp1_21', 'amp_Fp1_22', 'amp_Fp1_23', 'amp_Fp1_24', 'amp_Fp1_25', 'amp_Fp1_26', 'amp_Fp1_27', 'amp_Fp1_28', 'amp_Fp1_29', 'amp_Fp1_30', 'amp_Fp1_31', 'amp_Fp1_32', 'amp_Fp1_33', 'amp_Fp1_34', 'amp_Fp1_35', 'amp_Fp1_36', 'amp_Fp1_37', 'amp_Fp1_38', 'amp_Fp1_39', 'amp_Fp1_40', 'amp_Fp1_41', 'amp_Fp1_42', 'amp_Fp1_43', 'amp_Fp1_44', 'amp_Fp1_45', 'amp_Fp1_46', 'amp_Fp1_47', 'amp_Fp1_48', 'amp_Fp1_49', 'amp_Fp1_50', 'amp_Fp1_51', 'amp_Fp1_52', 'amp_Fp1_53', 'amp_Fp1_54', 'amp_Fp1_55', 'amp_Fp1_56', 'amp_Fp1_57', 'amp_Fp1_58', 'amp_Fp1_59', 'phase_Fp1_0', 'phase_Fp1_1', 'phase_Fp1_2', 'phase_Fp1_3', 'phase_Fp1_4', 'phase_Fp1_5', 'phase_Fp1_6', 'phase_Fp1_7', 'phase_Fp1_8', 'phase_Fp1_9', 'phase_F