In [None]:
# the folder train eegs has a lot of parquet files. Read each of them and store the results in a dataframe
import pandas as pd
import pyarrow.parquet as pq
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import polars as pl
import dask
dask.config.set({'dataframe.query-planning': True})
import dask.dataframe as dd
from dask.distributed import Client
from tqdm import tqdm

tqdm.pandas()

In [None]:
# Label metadata for every training example (one row per labelled EEG window).
TRAIN_CSV = '../data/raw/train.csv'
train = pd.read_csv(TRAIN_CSV)

In [None]:
# Number of labelled rows in train.csv
train.shape[0]

In [None]:
# Check whether every eeg_id maps to exactly one patient_id.
# groupby().nunique() is the built-in, vectorised equivalent of calling nunique()
# on each group through apply, and is far faster on the full train table.
patient_counts_per_eeg = train.groupby('eeg_id')['patient_id'].nunique()
different_id_eeg_ids = patient_counts_per_eeg[patient_counts_per_eeg != 1].index.tolist()

In [None]:
# Report the eeg_ids that are associated with more than one patient.
n_different_id = len(different_id_eeg_ids)
print(f'There are {n_different_id} eeg_ids with different patient_ids')
print(f'The eeg_ids are {different_id_eeg_ids}')


In [None]:
# Check whether the last 6 columns of train (the label columns) are constant within each eeg_id.
# NOTE(review): assumes the label columns are the last 6 of train.csv — confirm if the CSV changes.
# groupby()[cols].nunique() is the vectorised equivalent of a per-group apply.
label_columns = list(train.columns[-6:])
label_counts_per_eeg = train.groupby('eeg_id')[label_columns].nunique()
has_varying_labels = label_counts_per_eeg.ne(1).any(axis=1)
different_labels_eeg_ids = has_varying_labels[has_varying_labels].index.tolist()

In [None]:
# Report the eeg_ids whose label columns vary between their rows.
n_different_labels = len(different_labels_eeg_ids)
print(f'There are {n_different_labels} eeg_ids with different labels in the last 6 columns')
print(f'These eeg_ids are: {different_labels_eeg_ids}')

In [None]:
# First rows belonging to eeg_ids with varying labels
train.loc[train['eeg_id'].isin(different_labels_eeg_ids)].head(15)

In [None]:
# Distinct eeg_ids, in order of first appearance
train.eeg_id.unique()

In [None]:
# Load the recording of the first eeg_id in train to inspect its shape and columns.
# unique() preserves order of appearance, so its first element equals the first row's id.
first_eeg_id = train['eeg_id'].iloc[0]
eeg = pq.read_table(f"../data/raw/train_eegs/{first_eeg_id}.parquet").to_pandas()
eeg.shape

In [None]:
# Peek at the first samples of the loaded recording
eeg.head(n=5)

In [None]:
# Load every EEG recording referenced in `train` into memory (the whole recording per
# eeg_id, not windows) while tracking the length of the longest recording.
import pickle


def get_eegs(x, all_eegs, moving_max):
    """Read the parquet recording for one ``eeg_id`` group and store it in ``all_eegs``.

    Args:
        x: the rows of ``train`` sharing a single ``eeg_id`` (one groupby group).
        all_eegs: dict updated in place, mapping eeg_id -> recording DataFrame.
        moving_max: one-element mutable array; element 0 is raised to the longest
            recording length (in rows) seen so far.
    """
    eeg_id = x.eeg_id.iloc[0]
    path = f"../data/raw/train_eegs/{eeg_id}.parquet"
    recording = pq.read_table(path).to_pandas()
    all_eegs[eeg_id] = recording
    if len(recording) > moving_max[0]:
        moving_max[0] = len(recording)



all_eegs = dict()
moving_max = np.array([0])
# Load every EEG recording referenced by `train` (whole recordings, keyed by eeg_id).
# The result is cached in a pickle so a notebook re-run does not re-read every parquet file.
# NOTE(review): pickle.load can execute arbitrary code — only load a cache file you created.
EEG_CACHE_PATH = 'all_eegs.pkl'
if os.path.exists(EEG_CACHE_PATH):
    with open(EEG_CACHE_PATH, 'rb') as f:
        all_eegs = pickle.load(f)
    # get_eegs is what normally updates moving_max; recompute it from the cached
    # recordings so the printed maximum is correct on the cache-hit path too.
    moving_max[0] = max((len(eeg_df) for eeg_df in all_eegs.values()), default=0)
else:
    train.groupby('eeg_id').progress_apply(lambda x: get_eegs(x, all_eegs, moving_max))
    print('Saving all_eegs to pickle file')
    with open(EEG_CACHE_PATH, 'wb') as f:
        pickle.dump(all_eegs, f)
    print('all_eegs saved to pickle file')
print(moving_max)


In [None]:
def get_offsets(x, all_offsets):
    """Record the label offsets of one ``eeg_id`` group.

    Args:
        x: the rows of ``train`` sharing a single ``eeg_id`` (one groupby group).
        all_offsets: dict updated in place, mapping eeg_id -> list of the group's
            ``eeg_label_offset_seconds`` values converted with ``int`` (truncated).
    """
    key = x.eeg_id.iloc[0]
    all_offsets[key] = [int(seconds) for seconds in x.eeg_label_offset_seconds]


# Map each eeg_id to the list of its label offsets (whole seconds).
all_offsets = {}
train.groupby('eeg_id').progress_apply(lambda group: get_offsets(group, all_offsets))
len(all_offsets)


In [None]:
# Label offsets collected for one example recording
example_eeg_id = 11127485
all_offsets[example_eeg_id]

In [None]:
# Make a pytorch dataset
from scipy.signal import decimate
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import time

class EEGDataset(Dataset):
    """Map-style dataset yielding fixed-length EEG windows and normalised vote targets.

    Item ``idx`` corresponds to row ``idx`` of ``metadata``: the window starts at that
    row's ``eeg_label_offset_seconds`` and spans ``window_seconds`` seconds of the
    recording ``all_eegs[eeg_id]``.

    Args:
        all_eegs: dict mapping eeg_id -> DataFrame whose rows are samples and whose
            columns are the recorded channels.
        metadata: DataFrame with ``eeg_id`` and ``eeg_label_offset_seconds`` columns;
            its last ``n_labels`` columns are the per-class vote counts.
        sample_rate: samples per second of the recordings (default 200).
        window_seconds: window length in seconds (default 50).
        n_labels: number of trailing metadata columns used as targets (default 6).
    """

    def __init__(self, all_eegs, metadata, sample_rate=200, window_seconds=50, n_labels=6):
        self.all_eegs = all_eegs
        self.metadata = metadata
        self.sample_rate = sample_rate
        self.window_seconds = window_seconds
        self.n_labels = n_labels
        # Known up front, so expose it before the first __getitem__ call.
        self.label_names = metadata.columns[-n_labels:]
        # NOTE(review): assumes all recordings share the same channel columns; it is
        # refreshed from the accessed recording on every __getitem__.
        first_eeg = next(iter(all_eegs.values()), None)
        self.column_names = None if first_eeg is None else first_eeg.columns

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(eeg_window, labels)`` as float32 tensors.

        ``eeg_window`` has shape ``(window samples, channels)`` with NaNs replaced by 0
        (it is shorter if the recording ends before the window does). ``labels`` are
        the vote counts divided by their total; an all-zero vote row yields zeros
        instead of NaN.
        """
        row = self.metadata.iloc[idx]
        eeg_id = row['eeg_id']
        # Round to the nearest sample instead of truncating the offset to whole seconds.
        start = int(round(float(row['eeg_label_offset_seconds']) * self.sample_rate))
        stop = start + int(round(self.window_seconds * self.sample_rate))
        eeg = self.all_eegs[eeg_id].iloc[start:stop, :]
        self.column_names = eeg.columns
        # Missing samples would propagate NaN into the model, so zero them.
        eeg_arr = eeg.fillna(0).to_numpy(dtype=np.float32)

        votes = self.metadata.iloc[idx, -self.n_labels:].to_numpy(dtype=np.float64)
        total = votes.sum()
        # Normalise to a probability vector; 0/0 would give NaN targets, so keep zeros.
        labels_arr = (votes / total if total > 0 else votes).astype(np.float32)
        return torch.from_numpy(eeg_arr), torch.from_numpy(labels_arr)



In [None]:
# Build the dataset over every labelled row of train, backed by the in-memory recordings.
eeg_dataset = EEGDataset(all_eegs=all_eegs, metadata=train)

In [None]:
# Channel names, then label (vote) column names, as exposed by the dataset
for names in (eeg_dataset.column_names, eeg_dataset.label_names):
    print(names)

In [None]:
# Sanity-check one dataset item: plot its mean-removed channels and their magnitude spectra.
SAMPLE_RATE_HZ = 200  # matches the 200-samples-per-second window slicing in EEGDataset
eeg, labels = eeg_dataset[27]
print(eeg_dataset.column_names)
eeg = eeg.numpy()
labels = labels.numpy()
# Remove each channel's mean so the DC component does not dominate the spectrum.
eeg -= eeg.mean(axis=0)
print(eeg_dataset.label_names)
print(labels)

fig, ax = plt.subplots(2, 1, figsize=(20, 10))
ax[0].plot(eeg)
ax[0].set(title='EEG sample 27 (per-channel mean removed)', xlabel='sample', ylabel='amplitude')
eeg_fft = np.fft.rfft(eeg, axis=0)
# rfftfreq gives the exact bin frequencies for both even and odd window lengths.
ax[1].plot(np.fft.rfftfreq(len(eeg), d=1 / SAMPLE_RATE_HZ), np.abs(eeg_fft))
ax[1].set(title='Magnitude spectrum', xlabel='frequency [Hz]', ylabel='|FFT|')
plt.show()



In [None]:
# Downsample the mean-removed EEG by DECIMATION_FACTOR with scipy.signal.decimate, which
# applies an anti-aliasing low-pass filter before keeping every DECIMATION_FACTOR-th sample.
# No separate low-pass filter has been applied to `eeg` before this cell.
from scipy.signal import decimate
import plotly.graph_objects as go

SAMPLE_RATE_HZ = 200
DECIMATION_FACTOR = 5
decimated_rate_hz = SAMPLE_RATE_HZ / DECIMATION_FACTOR  # 40 Hz -> 20 Hz Nyquist

eeg_decimated = decimate(eeg, DECIMATION_FACTOR, axis=0)
decimated_eeg_fft = np.fft.rfft(eeg_decimated, axis=0)

# Time-domain view of the decimated channels, matching the spectrum plotted below.
fig = go.Figure()
for i in range(eeg_decimated.shape[1]):
    fig.add_trace(go.Scatter(x=np.arange(len(eeg_decimated)), y=eeg_decimated[:, i], name=f'channel {i}'))
fig.update_layout(title='Decimated EEG (scipy.signal.decimate)', xaxis_title='sample', yaxis_title='amplitude')
fig.show()

# Magnitude spectrum of the decimated channels, from 0 Hz to the decimated Nyquist frequency.
decimated_freqs_hz = np.fft.rfftfreq(len(eeg_decimated), d=1 / decimated_rate_hz)
fig = go.Figure()
for i in range(decimated_eeg_fft.shape[1]):
    fig.add_trace(go.Scatter(x=decimated_freqs_hz, y=np.abs(decimated_eeg_fft[:, i]), name=f'channel {i}'))
fig.update_layout(title='Spectrum of decimated EEG', xaxis_title='frequency [Hz]', yaxis_title='|FFT|')
fig.show()


In [None]:
import numpy as np
from scipy.signal import firwin

# Low-pass FIR taps, designed with the Hamming-window method.
sample_rate = 200        # EEG sampling rate in Hz
cutoff_frequency = 17    # Hz; below the 20 Hz Nyquist of a 5x-decimated 200 Hz signal
numtaps = 101            # filter length; odd, so the group delay is (numtaps - 1) / 2 samples

fir_coefficients = firwin(
    numtaps=numtaps,
    cutoff=cutoff_frequency,
    window='hamming',
    fs=sample_rate,
)


In [None]:
import torch
import torch.nn.functional as F

# Torch version of "low-pass FIR filter, then decimate", applied to `eeg` — the
# mean-removed NumPy sample laid out as (time, channels). Uses the GPU when available
# and falls back to the CPU, so the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Copy into float32: keeps `eeg` untouched and matches the dtype of the filter taps.
eeg_gpu = torch.from_numpy(np.array(eeg, dtype=np.float32)).to(device)
print(eeg_gpu.shape)

# FIR taps as a conv1d kernel of shape (out_channels=1, in_channels=1, numtaps).
fir_coefficients_tensor = torch.tensor(fir_coefficients, dtype=torch.float32, device=device).view(1, 1, -1)

# (time, channels) -> (channels, 1, time): each channel is filtered as its own batch item.
eeg_gpu_reshaped = eeg_gpu.transpose(0, 1).unsqueeze(1)

# conv1d computes cross-correlation; firwin designs linear-phase (symmetric) taps, so this
# equals convolution. padding='same' keeps the output as long as the input.
eeg_filtered = F.conv1d(eeg_gpu_reshaped, fir_coefficients_tensor, padding='same')

# Keep every `decimation_factor`-th filtered sample and restore the (time, channels) layout.
decimation_factor = 5
eeg_decimated_tensor = eeg_filtered[:, :, ::decimation_factor].squeeze(1).transpose(0, 1)


In [None]:
# Compare the torch-decimated shape with the scipy-decimated shape
for decimated in (eeg_decimated_tensor, eeg_decimated):
    print(decimated.shape)

In [None]:
# Plot the FIR-filtered, decimated signal from the torch cell (time domain and spectrum).
# This rebinds `eeg_decimated`, replacing the scipy result from the earlier cell.
eeg_decimated = eeg_decimated_tensor.cpu().numpy()
decimated_eeg_fft = np.fft.rfft(eeg_decimated, axis=0)
# Sampling rate after keeping every `decimation_factor`-th sample of the 200 Hz EEG;
# deriving the axis from it keeps the plot correct if the factor changes.
fir_decimated_rate_hz = 200 / decimation_factor
fir_decimated_freqs_hz = np.fft.rfftfreq(len(eeg_decimated), d=1 / fir_decimated_rate_hz)

fig = go.Figure()
for i in range(eeg_decimated.shape[1]):
    fig.add_trace(go.Scatter(x=np.arange(len(eeg_decimated)), y=eeg_decimated[:, i], name=f'channel {i}'))
fig.update_layout(title='FIR-filtered and decimated EEG (torch)', xaxis_title='sample', yaxis_title='amplitude')
fig.show()

fig = go.Figure()
for i in range(eeg_decimated.shape[1]):
    fig.add_trace(go.Scatter(x=fir_decimated_freqs_hz, y=np.abs(decimated_eeg_fft[:, i]), name=f'channel {i}'))
fig.update_layout(title='Spectrum of FIR-filtered and decimated EEG', xaxis_title='frequency [Hz]', yaxis_title='|FFT|')
fig.show()

In [None]:
# Compute and plot a log-power spectrogram for each channel of the decimated EEG.
import torchaudio.transforms as T
import torchaudio

spec = T.Spectrogram(n_fft=511, hop_length=1)
LOG_EPS = 1e-10  # keeps log2 finite where the power spectrum is exactly zero

for channel in range(eeg_decimated.shape[1]):
    # float32 copy: matches the spectrogram's float32 window even if eeg_decimated is float64.
    eeg_channel = np.array(eeg_decimated[:, channel], dtype=np.float32)
    eeg_tensor = torch.from_numpy(eeg_channel)
    print(eeg_tensor.shape)

    spec_tensor = spec(eeg_tensor)
    print(spec_tensor.shape)

    fig, ax = plt.subplots(figsize=(20, 10))
    mesh = ax.pcolormesh((spec_tensor + LOG_EPS).log2().numpy())
    ax.set(title=f'Spectrogram, channel {channel}', xlabel='frame (hop = 1 sample)', ylabel='frequency bin')
    fig.colorbar(mesh, ax=ax, label='log2 power')
    plt.show()
    plt.close(fig)