# In [None]:
# @title
# Mount Google Drive (Colab only) so the recordings under MyDrive can be listed below.
from google.colab import drive
drive.mount('/content/drive')

import os

import matplotlib.pyplot as plt  # pyplot (not the bare package) provides figure()/plot()/legend()
import numpy as np
import scipy
import scipy.signal  # load the submodule explicitly instead of relying on lazy attribute access
import torch
from torch.utils.data import DataLoader as DL
from torch.utils.data import TensorDataset as TData

# Locate the recordings (one .npy file each) in the mounted Drive folder.
new_base = '/content/drive/MyDrive/EEGData'
files = os.listdir(new_base)
paths = [os.path.join(new_base, file) for file in files]
# File names are '_'-separated; field 3 of the name without its extension is the
# session number. os.path.splitext removes only the extension, whereas
# str.rstrip('.npy') would strip any trailing '.', 'n', 'p' or 'y' characters.
file_splits = [os.path.splitext(file)[0].split('_')[3] for file in files]
# Parse to int *before* de-duplicating so e.g. '7' and '07' count as one session.
sessions = {int(session) for session in file_splits}
sorted_sessions = sorted(sessions)[8:]  # the 8 lowest-numbered sessions are not used
for excluded_session in (31, 34, 41):  # hand-excluded sessions (reason not recorded here)
    sorted_sessions.remove(excluded_session)
print(sorted_sessions)  # original note: 17 sessions in total; 45-57 are the new ones
# NOTE: this fixed split may be unnecessary if k-fold cross-validation is used instead.

import random

# Optional seed for the validation/test split. None keeps using the global
# `random` module (a fresh split on each run unless it was seeded elsewhere);
# an int makes the split reproducible.
SPLIT_SEED = None
split_rng = random if SPLIT_SEED is None else random.Random(SPLIT_SEED)
# NOTE(review): 12 sessions go to validation/test; if there are 17 sessions as
# noted above, only 5 remain for training - confirm this is intended.
chosen_sessions = split_rng.sample(sorted_sessions, k=12)
val_sessions, test_sessions = chosen_sessions[:6], chosen_sessions[6:]
print(val_sessions, test_sessions)
# Electrode re-ordering used before 2D convolution: rows of a recording are
# permuted, presumably so that electrodes belonging to the same group are adjacent.
def channel_rearrangment(sig, channel_order):
    """Return a new array whose row ``k`` is ``sig[channel_order[k] - 1]``.

    ``channel_order`` holds 1-based channel numbers. The result has the same
    shape and dtype as ``sig``; any rows past ``len(channel_order)`` stay zero.
    """
    out = np.zeros_like(sig)
    zero_based = (number - 1 for number in channel_order)
    for dest_row, src_row in enumerate(zero_based):
        out[dest_row] = sig[src_row]
    return out


# 1-based source channel for each output row, passed to channel_rearrangment.
ordered_channels = [1, 9, 11, 3, 2, 12, 10, 4, 13, 5, 15, 7, 14, 16, 6, 8]
# Zero-phase Butterworth band-pass filtering, applied along axis 1.
def bandpass_filter(signal, crit_freq=(1, 40), sampling_freq=125, plot=False, channel=0):
    """Band-pass filter ``signal`` along axis 1 with a zero-phase Butterworth filter.

    Args:
        signal: array filtered along axis 1 (in this notebook: rows are channels,
            columns are time samples).
        crit_freq: (low, high) pass-band edges in Hz.
        sampling_freq: sampling rate of ``signal`` in Hz.
        plot: if true, plot the min-max normalised input and output of row ``channel``.
        channel: row index to plot when ``plot`` is true.

    Returns:
        The filtered array, same shape as ``signal``.

    Raises:
        ValueError: raised by ``scipy.signal.filtfilt`` when the signal is not
            longer than its default edge-padding length along axis 1.
    """
    from scipy.signal import butter, filtfilt

    order = 4  # prototype order; scipy's band-pass design doubles it
    b, a = butter(order, crit_freq, btype='bandpass', fs=sampling_freq)
    # Forward-backward filtering cancels the filter's phase delay.
    processed_signal = filtfilt(b, a, signal, axis=1)

    if plot:
        # Import pyplot locally so plotting works regardless of how `plt` is bound at module level.
        import matplotlib.pyplot as plt

        def _minmax(row):
            # Scale one row to [0, 1] so input and output share an axis.
            return (row - row.min()) / (row.max() - row.min())

        raw_row = _minmax(np.asarray(signal)[channel])
        filtered_row = _minmax(processed_signal[channel])
        plt.figure()
        plt.xlabel('Time')
        plt.ylabel(f'Normalized amplitude of channel {channel}')
        plt.title(f'{crit_freq[0]}-{crit_freq[1]}Hz bandpass filter')
        plt.plot(np.arange(raw_row.size), raw_row, label='Input')
        plt.plot(np.arange(filtered_row.size), filtered_row, label='Transformed')
        plt.legend()

    return processed_signal


def _seconds_to_samples(seconds, sampling_freq):
    """Convert a duration in seconds to a whole number of samples (truncating).

    The tiny tolerance keeps float products such as 0.29 * 100 == 28.999999999999996
    from losing a sample to truncation.
    """
    return int(seconds * sampling_freq + 1e-9)


# Segment EEG data into overlapping windows given sampling_freq (Hz),
# window_size (s) and window_shift (s).
def segmentation(signal, sampling_freq=125, window_size=1, window_shift=0.016):
    """Split ``signal`` into fixed-length, possibly overlapping windows along axis 1.

    Args:
        signal: 2-D array sliced along axis 1 (time).
        sampling_freq: sampling rate in Hz, used to convert seconds to samples.
        window_size: window length in seconds.
        window_shift: hop between consecutive window starts in seconds.

    Returns:
        A list of slices ``signal[:, start:start + w]`` (numpy views), one for each
        start ``0, hop, 2*hop, ...`` that still leaves a full window. The list is
        empty when the signal is shorter than one window.

    Raises:
        ValueError: if the window or the hop amounts to fewer than one sample;
            a zero hop would otherwise never advance.
    """
    w_size = _seconds_to_samples(window_size, sampling_freq)
    w_shift = _seconds_to_samples(window_shift, sampling_freq)
    if w_size <= 0 or w_shift <= 0:
        raise ValueError(
            f'window ({w_size} samples) and shift ({w_shift} samples) must both be '
            f'at least one sample at {sampling_freq} Hz'
        )
    last_start = signal.shape[1] - w_size
    return [signal[:, start:start + w_size] for start in range(0, last_start + 1, w_shift)]

# Apply every preprocessing step and route the resulting windows into the
# train / validation / test splits according to each recording's session.
# NOTE(review): bandpass_filter is told the data is sampled at 125 Hz, but
# segmentation is told 128 Hz, so the "1.5 s" windows are 192 samples rather
# than 187. One of the two rates is presumably wrong; both are kept as they were
# because downstream model input sizes may depend on 192-sample windows.
FILTER_FS = 125
SEGMENT_FS = 128
SEGMENT_WINDOW_S = 1.5
SEGMENT_SHIFT_S = 0.0175
# Recordings shorter than one window yield no segments, so they are skipped
# before filtering (filtfilt raises on inputs shorter than its edge padding).
min_samples = int(SEGMENT_FS * SEGMENT_WINDOW_S)

train_eeg = []
train_labels = []
valid_eeg = []
valid_labels = []
test_eeg = []
test_labels = []
for name, path in zip(files, paths):
    # Name fields from index 2 on: details[0] is the class label, details[1] the session.
    details = os.path.splitext(name)[0].split('_')[2:]
    session = int(details[1])
    if session not in sorted_sessions:
        continue  # checked before np.load so excluded recordings are never read
    sig = np.load(path)[:, 1:]  # removing first time step because it is inaccurate
    if sig.shape[1] < min_samples:  # also excludes empty recordings
        continue
    reindexed_signal = channel_rearrangment(sig, ordered_channels)
    filtered_sig = bandpass_filter(reindexed_signal, [5, 40], FILTER_FS)
    # Standard scaling of each row along time (zero mean, unit variance).
    normed_sig = (filtered_sig - np.mean(filtered_sig, 1, keepdims=True)) / np.std(filtered_sig, 1, keepdims=True)
    if np.isnan(normed_sig).any():
        print(name)  # report and skip recordings that contain NaNs after scaling
        continue
    signals = segmentation(normed_sig, SEGMENT_FS, window_size=SEGMENT_WINDOW_S, window_shift=SEGMENT_SHIFT_S)
    labels = [int(details[0])] * len(signals)
    if session in test_sessions:
        test_eeg.extend(signals)
        test_labels.extend(labels)
    elif session in val_sessions:
        valid_eeg.extend(signals)
        valid_labels.extend(labels)
    else:
        train_eeg.extend(signals)
        train_labels.extend(labels)

# Turn the per-split window lists and raw labels into tensors and DataLoaders.
class_to_idx = {1: 0, 3: 1}  # raw class label -> one-hot column index


def _stack_windows(windows, split_name):
    """Stack equally shaped 2-D windows into a float32 tensor of shape (N, rows, cols).

    Raises ValueError when ``windows`` is empty, because the window shape cannot
    be inferred from an empty list.
    """
    if not windows:
        raise ValueError(f'no EEG windows were produced for the {split_name} split')
    # float32 matches the dtype of the default torch.zeros buffers used previously.
    return torch.from_numpy(np.asarray(windows, dtype=np.float32))


def _one_hot(labels):
    """Return a (len(labels), len(class_to_idx)) float32 one-hot tensor.

    Raises KeyError for a label that is not in class_to_idx.
    """
    columns = torch.tensor([class_to_idx[label] for label in labels], dtype=torch.long)
    encoded = torch.zeros(len(labels), len(class_to_idx))
    encoded[torch.arange(len(labels)), columns] = 1
    return encoded


train_eeg_tensor = _stack_windows(train_eeg, 'train')  # dims 1 and 2 are not transposed
valid_eeg_tensor = _stack_windows(valid_eeg, 'validation')
test_eeg_tensor = _stack_windows(test_eeg, 'test')
train_label_tensor = _one_hot(train_labels)
valid_label_tensor = _one_hot(valid_labels)
test_label_tensor = _one_hot(test_labels)

train_ds = TData(train_eeg_tensor, train_label_tensor)
valid_ds = TData(valid_eeg_tensor, valid_label_tensor)
test_ds = TData(test_eeg_tensor, test_label_tensor)
train_dl = DL(train_ds, batch_size=64, shuffle=True, drop_last=True)
# NOTE(review): with drop_last=True the evaluation loaders silently omit up to
# batch_size - 1 windows, and shuffling is unnecessary for evaluation;
# shuffle=False, drop_last=False is the usual choice. Kept as-is in case
# downstream code relies on full, fixed-size batches.
valid_dl = DL(valid_ds, batch_size=64, shuffle=True, drop_last=True)
test_dl = DL(test_ds, batch_size=16, shuffle=True, drop_last=True)