In [1]:
import os
import numpy as np
import scipy.io as sio
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, BatchNormalization, LSTM, Bidirectional, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import time
import pandas as pd


## TPUs in Colab
https://colab.research.google.com/notebooks/tpu.ipynb#scrollTo=ovFDeMgtjqW4

In [10]:
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

Tensorflow version 2.12.0
Running on TPU  ['10.6.142.98:8470']




KeyboardInterrupt: ignored

## Set up TPU environment:

In [2]:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)


## Load and preprocess data:

In [3]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [90]:
import os
import numpy as np
import h5py
import shutil
import gc
import time
from scipy.interpolate import interp1d

In [91]:
def interpolate_channels(eeg_data, target_channels):
    current_channels = eeg_data.shape[0]
    
    if current_channels < target_channels:
        # Generate the new channel indices
        old_indices = np.linspace(0, current_channels - 1, current_channels)
        new_indices = np.linspace(0, current_channels - 1, target_channels)

        # Interpolate the data
        interpolated_data = []
        for i in range(eeg_data.shape[1]):
            f = interp1d(old_indices, eeg_data[:, i], kind='cubic', fill_value='extrapolate')
            interpolated_data.append(f(new_indices))

        eeg_data = np.array(interpolated_data).T

    return eeg_data


In [92]:
def load_subject_data(subject_id):
    file_path = os.path.join(f'/content/drive/MyDrive/EEG_detection/EEG_data/subject_{subject_id:02d}.mat')
    with h5py.File(file_path, 'r') as f:
        eeg_data_struct = f['eeg_data']
        labels_struct = f['labels']

        # Access the 'filteredData' cell within the 'eeg_data' struct
        eeg_data = eeg_data_struct['filteredData']
        # print(eeg_data.shape)

        # Access the 'sumLabel' cell within the 'labels' struct
        labels = labels_struct['sumLabel']
        # print(labels.shape)
    
        # Interpolate channels to have the same number (e.g., 26 channels)
        eeg_data = interpolate_channels(eeg_data, 26)
        print("interpolate: ", eeg_data.shape)

    return eeg_data, labels


def preprocess_data(eeg_data, labels):
    window_size = 2 * 256  # 2 seconds * 256 samples per second
    step_size = int(window_size * 0.25)  # 75% overlap

    eeg_segments = []
    label_segments = []

    print(f"EEG data shape: {eeg_data.shape}")
    print(f"Labels shape: {labels.shape}")

    for i in range(0, eeg_data.shape[1] - window_size + 1, step_size):
        eeg_segments.append(eeg_data[:, i:i + window_size])
        label_segments.append(labels[:, i:i + window_size])

    eeg_segments = np.array(eeg_segments)
    label_segments = np.array(label_segments)

    return eeg_segments, label_segments


In [93]:
def load_subject15_data(subject_id):
    file_path = (f'/content/drive/MyDrive/EEG_detection/闹心玩意/f_15.mat')
    with h5py.File(file_path, 'r') as f:
        eeg_data = f['filteredData']
        labels = f['sumLabel']

        print(eeg_data.shape)
        # print(labels.shape)

        # Interpolate channels to have the same number (e.g., 26 channels)
        eeg_data = interpolate_channels(eeg_data, 26)
        print(eeg_data.shape)

    return eeg_data, labels

print(f"Processing subject {subject_id}...")
eeg_data_loaded, labels_loaded = load_subject15_data(subject_id)

Processing subject 24...
(36873216, 26)
(36873216, 26)


In [94]:
l = [5,6,7,8,9,10,11,12,13,14]

In [95]:
import os
import time
import shutil
import tensorflow as tf

# Loop through all subjects and preprocess data
for subject_id in range(2,25):
    start_time = time.time()
    print(f"Processing subject {subject_id}...")
    if subject_id == 15:
      eeg_data_loaded, labels_loaded = load_subject15_data(subject_id)
    elif subject_id == 4:
      continue
    else:
      eeg_data_loaded, labels_loaded = load_subject_data(subject_id)

    # eeg_data, labels = preprocess_data(eeg_data_loaded, labels_loaded)

    # # Save preprocessed data locally
    # local_output_dir = '/content/preprocessed'
    # if not os.path.exists(local_output_dir):
    #     os.makedirs(local_output_dir)
    
    # dataset = tf.data.Dataset.from_tensor_slices((eeg_data, labels))
    # tf.data.experimental.save(dataset, os.path.join(local_output_dir, f'dataset_subject_{subject_id}'))
    # print(f"Finished processing subject {subject_id}.")

    # # Copy the preprocessed data to Google Drive
    # google_drive_output_dir = '/content/drive/MyDrive/EEG_detection/preprocessed'
    # shutil.copytree(os.path.join(local_output_dir, f'dataset_subject_{subject_id}'), os.path.join(google_drive_output_dir, f'dataset_subject_{subject_id}'))

    # elapsed_time = time.time() - start_time
    # print(f"Finished processing subject {subject_id} in {elapsed_time:.2f} seconds.")


Processing subject 2...
interpolate:  (32501504, 22)
Processing subject 3...
interpolate:  (35022336, 22)
Processing subject 4...
Processing subject 5...
interpolate:  (35944960, 22)
Processing subject 6...
interpolate:  (61502976, 22)
Processing subject 7...
interpolate:  (61795328, 22)
Processing subject 8...
interpolate:  (18437888, 22)
Processing subject 9...
interpolate:  (62550528, 22)
Processing subject 10...
interpolate:  (46101504, 22)
Processing subject 11...
interpolate:  (32065792, 22)
Processing subject 12...
interpolate:  (19065856, 22)
Processing subject 13...
interpolate:  (30412800, 18)
Processing subject 14...
interpolate:  (23961600, 22)
Processing subject 15...
(36873216, 26)
(36873216, 26)
Processing subject 16...
interpolate:  (17510400, 18)
Processing subject 17...
interpolate:  (19359744, 18)
Processing subject 18...
interpolate:  (32840960, 18)
Processing subject 19...
interpolate:  (27582976, 18)
Processing subject 20...
interpolate:  (25437696, 22)
Processing

## Model:

In [None]:
def create_model(input_shape):
    with strategy.scope():
        model = Sequential([
            Conv1D(16, kernel_size=5, activation='relu', input_shape=input_shape),
            MaxPooling1D(pool_size=2),
            BatchNormalization(),
            Conv1D(32, kernel_size=5, activation='relu'),
            MaxPooling1D(pool_size=2),
            BatchNormalization(),
            Conv1D(64, kernel_size=5, activation='relu'),
            MaxPooling1D(pool_size=2),
            BatchNormalization(),
            LSTM(64, return_sequences=True),
            Bidirectional(LSTM(64)),
            Dense(1, activation='sigmoid')
        ])

        model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
    return model


## compute metrics:

In [None]:
def compute_metrics(y_true, y_pred):
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy = accuracy_score(y_true, y_pred)
    return sensitivity, specificity, accuracy

def get_seizure_start_times(y_true, y_pred):
    marked_seizures = np.where(np.diff(np.concatenate(([False], y_true, [False]))).astype(int) == 1)[0]
    detected_seizures = np.where(np.diff(np.concatenate(([False], y_pred, [False]))).astype(int) == 1)[0]


## train

In [None]:
# sequence the pipeline with batched training
def train_and_evaluate_on_fold(fold, batch_size=4):
    test_subject_ids = list(range(fold * 4 + 1, (fold + 1) * 4 + 1))
    train_subject_ids = list(range(1, fold * 4 + 1)) + list(range((fold + 1) * 4 + 1, 25))

    model = create_model(input_shape)

    # Train the model in batches
    for _ in range(10):  # 10 epochs
        for i in range(0, len(train_subject_ids), batch_size):
            batch_subject_ids = train_subject_ids[i:i + batch_size]
            X_train_batch, y_train_batch = [], []

            for subject_id in batch_subject_ids:
                eeg_data, labels = load_subject_data(subject_id)
                segments, segment_labels = preprocess_data(eeg_data, labels)
                X_train_batch.append(segments)
                y_train_batch.append(segment_labels)

            X_train_batch = np.concatenate(X_train_batch, axis=0)
            y_train_batch = np.concatenate(y_train_batch, axis=0)

            model.train_on_batch(X_train_batch, y_train_batch)

    # Evaluate the model
    subjects_metrics = []
    for subject_id in test_subject_ids:
        eeg_data, labels = load_subject_data(subject_id)
        segments, segment_labels = preprocess_data(eeg_data, labels)
        X_test, y_test = segments, segment_labels

        start_time = time.time()
        y_pred_prob = model.predict(X_test)
        processing_time = time.time() - start_time

        y_pred = (y_pred_prob > 0.5).astype(int).reshape(-1)

        sensitivity, specificity, accuracy = compute_metrics(y_test, y_pred)
        roc = roc_auc_score(y_test, y_pred_prob)

        marked_seizures, detected_seizures = get_seizure_start_times(y_test, y_pred)
        seizure_durations = [np.min(np.abs(marked - detected_seizures)) / 256 for marked in marked_seizures]
        seizure_detection_time = np.mean(seizure_durations)

        subjects_metrics.append((subject_id, sensitivity, specificity, accuracy, roc, len(marked_seizures), len(detected_seizures), seizure_detection_time, processing_time))

    return subjects_metrics

input_shape = (512, 16)  # 2 seconds * 256 samples/second, 16 channels

all_subjects_metrics = []
n_folds = 6

for fold in range(n_folds):
    print(f"Training and evaluating model on fold {fold + 1}")
    subjects_metrics = train_and_evaluate_on_fold(fold)
    all_subjects_metrics.extend(subjects_metrics)

all_subjects_metrics.sort(key=lambda x: x[0])  # Sort the results by subject ID

results_df = pd.DataFrame(all_subjects_metrics, columns=['Subject', 'Sensitivity', 'Specificity', 'Accuracy', 'ROC', 'Marked Seizures', 'Detected Seizures', 'Seizure Detection Time (s)', 'Processing Time (s)'])
display(results_df)
