<a href="https://colab.research.google.com/github/NoahIslam/ShinozakiLabEEGMLModels/blob/main/EEG_Preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the paths under
# /content/drive/MyDrive used below resolve. `google.colab` is the Colab
# runtime's helper package, so this cell is Colab-specific.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install mne


Collecting mne
  Downloading mne-1.5.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mne
Successfully installed mne-1.5.0


In [None]:
import glob
import pandas as pd
import os
import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
import torchvision.models as models
from torchvision import transforms
import torch.optim as optim
import torch.nn as nn
import mne
import torch.nn.functional as F
dtype = torch.float32

In [None]:
# Root folder (on the mounted Drive) that is searched recursively for .edf files.
edf_loc = "/content/drive/MyDrive/EDF Files_copy/"
# Excel workbook that is read with pandas to obtain the clinical labels.
xcel_loc = "/content/drive/MyDrive/EEG_Clinical_Data_Master_File_Broadcasted.xlsx"

#gets list of all files (copied from original notebook)
def get_list_of_files(filepath, extension=".edf"):
  """Recursively collect the paths of all files under ``filepath`` with a given suffix.

  Args:
    filepath: Directory to search.
    extension: Required file-name suffix (case-sensitive). Defaults to ".edf".

  Returns:
    A list of paths built with ``os.path.join``. Files found in
    sub-directories come first, followed by the files that sit directly in
    ``filepath``.
  """
  subdir_files = []
  local_files = []
  # Sorted so the result does not depend on the OS directory ordering.
  for item in sorted(os.listdir(filepath)):
    full_path = os.path.join(filepath, item)
    if os.path.isdir(full_path):
      subdir_files.extend(get_list_of_files(full_path, extension))
    # Each entry is tested on its own name; the earlier version tested the
    # last entry of the directory for every candidate.
    elif os.path.isfile(full_path) and item.endswith(extension):
      local_files.append(full_path)
  return subdir_files + local_files

# Scan the Drive folder once; `files_1` is a second name bound to the same list object.
files_1 = files = get_list_of_files(edf_loc)

In [None]:
# Number of file paths returned by get_list_of_files.
print(len(files))

4048


In [None]:
#get xcel as csv
def xcel_csv(file_loc, csv_path='EEG_Data.csv'):
  """Load an Excel workbook into a DataFrame and write a CSV copy of it.

  Args:
    file_loc: Path of the Excel file, passed to ``pd.read_excel``.
    csv_path: Destination of the CSV copy (written without the index).
      Defaults to 'EEG_Data.csv' in the current working directory.

  Returns:
    The DataFrame read from the workbook.
  """
  df = pd.read_excel(file_loc)
  df.to_csv(csv_path, index=False)
  return df

# Clinical table as a DataFrame; calling xcel_csv also writes EEG_Data.csv as a side effect.
xcel = xcel_csv(xcel_loc)

In [None]:
# NOTE(review): `raw_data` is not read by any cell visible in this notebook — confirm before relying on it.
raw_data = []

def to_datetime(timedate):
  """Parse the study ID and recording date out of an EDF file path.

  The file name (text after the last '/') is split on '_': the first piece is
  the numeric EEG study ID and the second piece must end with an 8-character
  date. The date is read as MMDDYYYY when its last four characters are a year
  from 2016 to 2020, otherwise as YYYYMMDD.

  Args:
    timedate: Path or file name, e.g. ".../123_01152019_x.edf".

  Returns:
    ``(study_id, datetime)`` with the time set to midnight, or the empty tuple
    ``()`` when the name does not follow the pattern.
  """
  parts = timedate.split('/')[-1].split("_")
  # A name without '_' has no date field: report it as unparseable instead of
  # letting an IndexError escape.
  if len(parts) < 2:
    return ()
  EEG_ID, date = parts[0], parts[1]
  # Fewer than 8 characters cannot hold a full date.
  if len(date) < 8:
    return ()

  if date[-4:] in ('2016', '2017', '2018', '2019', '2020'):
    month, day, year = date[-8:-6], date[-6:-4], date[-4:]
  else:
    year, month, day = date[-8:-4], date[-4:-2], date[-2:]

  try:
    return (int(EEG_ID), datetime.datetime(int(year), int(month), int(day), 0, 0, 0))
  except (ValueError, OverflowError):
    # Non-numeric pieces or an impossible calendar date.
    return ()


def csv_to_dict(files, df):
  """Map each EDF path to its delirium label looked up in the clinical table.

  A file is matched to the rows of ``df`` whose "Date" and 'EEG STUDY ID'
  equal the date and ID parsed from the file name by ``to_datetime``; the
  label is the first matching value of the 'Delirious (0=N, 1=Y)' column.

  Args:
    files: Iterable of EDF file paths.
    df: DataFrame with "Date", 'EEG STUDY ID' and 'Delirious (0=N, 1=Y)' columns.

  Returns:
    dict of path -> label. A file is left out when its name cannot be parsed,
    no row matches, the label is '-', '?' or missing, or the character just
    before the 4-character extension is 'A' or 'B'.
  """
  filename_to_score = {}
  skipped = []  # files without a usable label; collected for debugging only
  for curr_file in files:
    inputs = to_datetime(curr_file)
    if inputs == ():
      # Unparseable name: skip it outright. Falling through would reuse the
      # previous file's lookup result (or hit an undefined variable on the
      # very first file).
      skipped.append(curr_file)
      continue

    study_id, study_date = inputs
    has_delirium = df.loc[(df["Date"] == study_date) & (df['EEG STUDY ID'] == study_id), 'Delirious (0=N, 1=Y)']
    if len(has_delirium) == 0:
      skipped.append(curr_file)
      continue

    label = has_delirium.values[0]
    # pd.isna copes with string cells, whereas np.isnan raises TypeError on them.
    if label == '-' or label == '?' or pd.isna(label):
      skipped.append(curr_file)
      continue

    # Filter out Fukuoda data
    if curr_file[-5] != 'A' and curr_file[-5] != 'B':
      filename_to_score[curr_file] = label
  return filename_to_score

# Path -> delirium label for every file that could be matched to a row of the clinical table.
filename_to_score = csv_to_dict(files_1, xcel)

In [None]:
# NOTE(review): `raw_data` is not read by any cell visible in this notebook — confirm before relying on it.
raw_data = []
# Sampling rate handed to raw.resample and to the band-pass filter inside edf_to_arr.
sfreq = 500

def standardize_data(data):
    """Z-score ``data`` using the mean and standard deviation of all its values.

    Args:
        data: Numeric array (anything ``np.mean`` / ``np.std`` accept).

    Returns:
        ``(data - mean) / std``. For a constant signal (standard deviation of
        zero) the centred data, i.e. all zeros, is returned instead of the
        NaNs a 0/0 division would produce.
    """
    centered = data - np.mean(data)
    spread = np.std(data)
    if spread == 0:
        return centered
    return centered / spread

def bandpass_filter(data, low_freq, high_freq, sfreq):
    """Band-pass ``data`` between ``low_freq`` and ``high_freq`` using MNE.

    Thin wrapper around ``mne.filter.filter_data`` with MNE's logging limited
    to errors.
    """
    filter_kwargs = {"l_freq": low_freq, "h_freq": high_freq, "verbose": 'ERROR'}
    filtered = mne.filter.filter_data(data, sfreq, **filter_kwargs)
    return filtered


def edf_to_arr(edf_path, duration_sec=128):
  """Load the first channel of an EDF recording as a cleaned float32 array.

  Steps: read the file, keep the first ``duration_sec`` seconds, resample to
  the notebook-level ``sfreq``, band-pass with cut-offs 0.5 and 50, then
  z-score.

  Args:
    edf_path: Path to the .edf file.
    duration_sec: Length of the window to keep, in seconds. Recordings shorter
      than this are skipped. Defaults to 128.

  Returns:
    float32 NumPy array for the first channel, or the int ``-1`` when the
    recording is too short or could not be processed (callers test for an int).
  """
  try:
    raw = mne.io.read_raw_edf(edf_path, verbose='ERROR')
    duration = raw.times[-1]  # time stamp of the last sample, in seconds
    if duration < duration_sec:
        print(f"Data duration for {edf_path} is less than {duration_sec} seconds. Skipping...")
        return -1
    raw = raw.crop(tmax=duration_sec)

    raw_resampled = raw.resample(sfreq, verbose='ERROR')
    eeg_data = raw_resampled.get_data()[0]  # first channel only
    eeg_data = bandpass_filter(eeg_data, 0.5, 50, sfreq)
    eeg_data = standardize_data(eeg_data)
    return eeg_data.astype(np.float32)

  except Exception as e:
    # Stay best-effort (one unreadable file must not stop the whole run), but
    # report why the file was dropped instead of discarding the error silently.
    print(f"Could not process {edf_path}: {type(e).__name__}: {e}. Skipping...")
    return -1

In [None]:
# Run every labelled recording through edf_to_arr and keep those that both
# loaded successfully (edf_to_arr returns an int on failure) and carry a 0/1 label.
data, labels = [], []
i = 1
for filename, label in filename_to_score.items():
    print(i)  # progress counter
    i += 1
    arr = edf_to_arr(filename)
    if type(arr) is int:
        continue
    if label in (1, 0):
        data.append(arr)
        labels.append(label)

In [None]:
# Both counts on separate lines; they should match.
print(len(data), len(labels), sep="\n")

In [None]:
# Sanity check: prints any label that is neither 0 nor 1 (expected: no output).
for label in labels:
  if label not in (1, 0):
    print(label)

In [None]:
import pickle

# To save the data: one pickle file per list, written to the working directory
for pkl_name, payload in (('data_cropped.pkl', data), ('labels_cropped.pkl', labels)):
    with open(pkl_name, 'wb') as f:
        pickle.dump(payload, f)

In [None]:
# Assuming you've already mounted Google Drive
import shutil

# Move both pickle files from the working directory to the mounted Drive.
# The last call's return value (the destination path) is what the cell displays.
shutil.move('data_cropped.pkl', '/content/drive/MyDrive/data_cropped.pkl')
shutil.move('labels_cropped.pkl', '/content/drive/MyDrive/labels_cropped.pkl')


'/content/drive/MyDrive/labels_cropped.pkl'

In [None]:
# Number of recordings kept after preprocessing.
print(len(data))

2738


In [None]:
# Number of labels kept; should equal len(data).
print(len(labels))

2738


In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the recordings; stratify=labels keeps the class proportions
# equal in both parts and random_state=42 makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42, stratify=labels)


In [None]:
from sklearn.utils import resample

# Combine the data and labels for easier resampling
combined = list(zip(X_train, y_train))

# Separate classes
class_0 = [item for item in combined if item[1] == 0]
class_1 = [item for item in combined if item[1] == 1]

# Resample class 1 with replacement until it has as many items as class 0
# (this assumes class_1 is the minority. If not, switch them around.)
# Only the training split is resampled, so the test split stays untouched.
class_1_upsampled = resample(class_1, replace=True, n_samples=len(class_0), random_state=42)

# Combine and shuffle. A seeded generator is used so the order is repeatable,
# in line with the seeded split and resample; the global np.random state is
# not consumed.
balanced_data = class_0 + class_1_upsampled
np.random.default_rng(42).shuffle(balanced_data)

X_train_balanced = [item[0] for item in balanced_data]
y_train_balanced = [item[1] for item in balanced_data]


In [None]:
from torch.nn.utils.rnn import pack_padded_sequence

def collate_fn(batch):
    """Turn a list of ``(sequence, label)`` pairs into a padded, length-sorted batch.

    Returns:
        Tuple ``(padded, lengths, labels)``: the sequences zero-padded to the
        longest one (batch dimension first), the sequence lengths in
        descending order, and the float32 labels re-ordered to match.
    """
    # Split the pairs into two parallel lists.
    seqs, targets = (list(column) for column in zip(*batch))

    # Longest sequence first.
    seq_lens = torch.tensor([len(s) for s in seqs])
    lengths, order = seq_lens.sort(0, descending=True)

    # Apply the same permutation to sequences and labels.
    reordered = [seqs[idx] for idx in order]
    targets_sorted = torch.tensor(targets, dtype=torch.float32)[order]

    padded = torch.nn.utils.rnn.pad_sequence(reordered, batch_first=True)
    return padded, lengths, targets_sorted



In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Convert to PyTorch tensors. Indexing with [:, None] inserts a new axis at
# position 1 (the per-time-step feature axis for the recordings).
X_train_tensors = [torch.tensor(d, dtype=torch.float32)[:, None] for d in X_train_balanced]
y_train_tensors = torch.tensor(y_train_balanced, dtype=torch.float32)[:, None]

X_test_tensors = [torch.tensor(d, dtype=torch.float32)[:, None] for d in X_test]
y_test_tensors = torch.tensor(y_test, dtype=torch.float32)[:, None]

# Datasets are plain lists of (sequence, label) pairs
train_dataset = list(zip(X_train_tensors, y_train_tensors))
test_dataset = list(zip(X_test_tensors, y_test_tensors))

# Both loaders share the batch size and the padding collate function;
# only the training loader shuffles.
BATCH_SIZE = 10
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [None]:
import torch.nn as nn

class DeliriumLSTM(nn.Module):
    """LSTM classifier that produces two logits per input sequence.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super(DeliriumLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 2)  # Two output units for binary classification

    def forward(self, x, lengths):
        """Classify a padded batch of sequences.

        Args:
            x: Padded batch, batch dimension first, sorted by decreasing
                length (required by ``pack_padded_sequence``'s default
                ``enforce_sorted=True``).
            lengths: 1-D tensor with the true length of each sequence.

        Returns:
            Tensor of shape (batch, 2) with the class logits.
        """
        # Packing makes the LSTM stop at each sequence's true end; lengths must be on the CPU.
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True)
        _, (h_n, _) = self.lstm(packed_x)

        # h_n[-1] is the top layer's hidden state at the last *valid* step of
        # every sequence. Indexing the padded output at position -1 would read
        # zero padding for any sequence shorter than the longest in the batch.
        return self.fc(h_n[-1])



In [None]:
# Hyperparameters
INPUT_SIZE = 1  # 1-dimensional input (one EEG channel)
HIDDEN_SIZE = 64  # size of the LSTM hidden state
NUM_LAYERS = 2  # number of stacked LSTM layers
LR = 0.001  # Adam learning rate

# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DeliriumLSTM(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS).to(device)
criterion = nn.CrossEntropyLoss()  # Cross-entropy over the model's two output logits
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()


In [None]:
EPOCHS = 10

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    # Batch variables carry a `batch_` prefix so the loop does not overwrite
    # the notebook-level `labels` list built during preprocessing.
    for batch_x, batch_lengths, batch_y in train_loader:
        # Lengths stay on the CPU: pack_padded_sequence needs them there.
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device).long()  # CrossEntropyLoss expects integer class indices

        # Forward pass
        outputs = model(batch_x, batch_lengths)
        loss = criterion(outputs, batch_y)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ---- validation pass (on the held-out test loader) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_x, batch_lengths, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device).long()

            outputs = model(batch_x, batch_lengths)
            val_loss += criterion(outputs, batch_y).item()

            _, predicted = torch.max(outputs, 1)
            total += batch_y.size(0)
            correct += (predicted == batch_y).sum().item()

    # Report per-batch averages for the whole epoch (the training figure used
    # to be the last batch's loss only); max(..., 1) avoids dividing by zero
    # when a loader is empty.
    mean_train_loss = train_loss / max(len(train_loader), 1)
    mean_val_loss = val_loss / max(len(test_loader), 1)
    val_accuracy = 100 * correct / max(total, 1)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Training Loss: {mean_train_loss:.4f}, Validation Loss: {mean_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")


In [None]:
# Final accuracy on the held-out test loader.
model.eval()
total = 0
correct = 0

with torch.no_grad():
    # `batch_`-prefixed names keep the notebook-level `labels` list intact.
    for batch_x, batch_lengths, batch_y in test_loader:
        # Lengths stay on the CPU: pack_padded_sequence needs them there.
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device).long()

        outputs = model(batch_x, batch_lengths)
        _, predicted = torch.max(outputs, 1)  # index of the larger logit = predicted class

        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

# max(..., 1) avoids a ZeroDivisionError when the loader is empty.
print(f"Accuracy on test set: {100 * correct / max(total, 1):.2f}%")