In [None]:
import mne
import numpy as np
import pandas as pd 
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from glob import glob
from tqdm import tqdm
import os

# Limit MNE's console logging to messages of level WARNING and above.
mne.set_log_level('WARNING')

# Integer event id assigned to each condition name, listed by the
# '0Suff' / '1Suff' / '2Suff' prefix of the name.
event_id_all = {
    # '0Suff' conditions
    '0Suff w/ Lat.': 1,
    '0Suff w/o Lat.': 2,
    '0Suff NW': 4,
    # '1Suff' conditions
    '1Suff w/ Lat.': 11,
    '1Suff w/o Lat.': 12,
    '1Suff PseudoStemNW': 14,
    '1Suff RealStemNW': 15,
    # '2Suff' conditions
    '2Suff w/ Lat.': 21,
    '2Suff w/o Lat.': 22,
    '2Suff Composite': 23,
    '2Suff PseudoStemNW': 24,
    '2Suff RealStemNW': 25,
}


def get_condition(id):
    """Return the condition name whose event id equals ``id``.

    ``event_id_all`` is scanned in insertion order and the first matching
    name is returned. When no condition carries that id the result is
    ``None``.
    """
    matching_names = (name for name, code in event_id_all.items() if code == id)
    return next(matching_names, None)

import mne
from glob import glob
from tqdm import tqdm

# Collect every path matching */*epo.fif under the MEG directory.
# NOTE(review): a later cell reassigns `epoch_files` to a list built from an
# explicit subject list, so this glob result appears to be superseded --
# confirm whether it is still needed.
epoch_files = glob('/scratch/alr664/multiple_affix/meg/*/*epo.fif')

In [None]:
# Project root and the MEG / log directories beneath it.
root = '/scratch/alr664/multiple_affix'
meg = f'{root}/meg'
logs = f'{root}/logs'

# Subject IDs whose epoch files are loaded.
full_dataset = [
    "A0394", "A0421", "A0446", "A0451", "A0468", "A0484",
    "A0495", "A0502", "A0503", "A0508", "A0509", "A0512",
    "A0513", "A0514", "A0516", "A0517", "A0518", "A0519",
    "A0520", "A0521", "A0522", "A0523", "A0524", "A0525",
]

# Non-hidden entries found in the MEG directory (kept for inspection; the
# epoch file list below is driven by `full_dataset`).
subjects = [entry for entry in os.listdir(meg) if not entry.startswith('.')]

# One '<subject>_rejection-epo.fif' path per listed subject, echoed as built.
epoch_files = []
for subject in full_dataset:
    subj_epoch_path = f'{meg}/{subject}/{subject}_rejection-epo.fif'
    print(subj_epoch_path)
    epoch_files.append(subj_epoch_path)

In [None]:
# Condition names to merge into each combined class, with the new event id.
# Every old name must be a key of `epochs.event_id`: combine_event_ids looks
# each one up and raises KeyError on an unknown name. The first group used to
# list '0Suff w/ Lat' (no trailing period), which does not match the
# '0Suff w/ Lat.' key used in `event_id_all`.
# NOTE(review): assumes the epoch files store the same condition names as
# `event_id_all` -- confirm against one file's `epochs.event_id`.
suffix_groups = [
    (['0Suff NW', '0Suff w/o Lat.', '0Suff w/ Lat.'],
     {'0Suff': 100}),
    (['1Suff PseudoStemNW', '1Suff RealStemNW', '1Suff w/ Lat.', '1Suff w/o Lat.'],
     {'1Suff': 101}),
    (['2Suff RealStemNW', '2Suff PseudoStemNW', '2Suff w/ Lat.', '2Suff w/o Lat.', '2Suff Composite'],
     {'2Suff': 102}),
]

# Per-file data arrays and label vectors, in the order of `epoch_files`.
meg_full_data = []
labels_full_data = []
for epoch_file in tqdm(epoch_files):
    print(epoch_file)
    epochs = mne.read_epochs(epoch_file)
    # Resample to 125 Hz, then keep the 0-0.6 s window.
    epochs = epochs.resample(125)
    epochs = epochs.crop(tmin=0., tmax=0.6)
    # Collapse the fine-grained conditions into the three combined classes.
    for old_names, new_event_id in suffix_groups:
        epochs = mne.epochs.combine_event_ids(epochs, old_names, new_event_id, True)
    meg_full_data.append(epochs.get_data())
    # Column 2 of the events array holds the event id of each epoch.
    labels_full_data.append(epochs.events[:, 2])
    # Release the Epochs object before loading the next file.
    del epochs

In [None]:
# Stack the per-file arrays collected by the loading loop along the first
# axis. The isinstance guards keep this cell safe to re-run: calling
# np.vstack on the already-stacked 3-D array would merge its first two axes,
# and np.concatenate on the already-joined 1-D label array raises.
if isinstance(meg_full_data, list):
    meg_full_data = np.vstack(meg_full_data)
if isinstance(labels_full_data, list):
    labels_full_data = np.concatenate(labels_full_data)

# Every epoch needs exactly one label before anything is written to disk.
if meg_full_data.shape[0] != labels_full_data.shape[0]:
    raise ValueError(
        "Data/label mismatch: {} epochs but {} labels".format(
            meg_full_data.shape[0], labels_full_data.shape[0]))

# Persist both arrays to the working directory.
np.save('./meg_full_data.npy', meg_full_data)
np.save('./labels.npy', labels_full_data)
print("meg full data shape: ", meg_full_data.shape, labels_full_data.shape)

In [None]:
# Rewrite the combined event ids 100 / 101 / 102 as class indices 0 / 1 / 2;
# any other value is left untouched.
for class_index, combined_id in enumerate((100, 101, 102)):
    labels_full_data = np.where(labels_full_data == combined_id,
                                class_index,
                                labels_full_data)

# Show which label values remain after the remapping.
np.unique(labels_full_data)

In [None]:
## Classification using GRU

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, GRU, Dropout
from tensorflow.keras.optimizers import Adam
import numpy as np
from tensorflow.keras.utils import to_categorical

# Define the model architecture: three stacked GRU layers followed by two
# Dense layers and a softmax output, with dropout after every hidden layer.
#
# The input keeps the per-epoch shape of `meg_full_data`, so the arrays can
# be fed to the model unchanged.
input_layer = Input(shape=meg_full_data.shape[1:])

# `meg_full_data` is built from `epochs.get_data()`, which MNE documents as
# (n_epochs, n_channels, n_times), while Keras recurrent layers read their
# input as (batch, timesteps, features). Swap the two per-epoch axes so the
# GRUs recur over time samples with channels as features, instead of
# recurring over channels.
x = tf.keras.layers.Permute((2, 1))(input_layer)

# The first two GRU layers return the full sequence so the next GRU has a
# sequence to consume.
x = GRU(128, return_sequences=True)(x)
x = Dropout(0.2)(x)  # Dropout for regularization
x = GRU(128, return_sequences=True)(x)
x = Dropout(0.2)(x)  # Dropout for regularization
x = GRU(128)(x)  # Last GRU layer does not return sequences
x = Dropout(0.2)(x)  # Dropout for regularization

# Dense layers before the output.
x = Dense(64, activation='relu')(x)
x = Dropout(0.2)(x)  # Dropout for regularization
x = Dense(64, activation='relu')(x)
x = Dropout(0.2)(x)  # Dropout for regularization

# One softmax unit per distinct label value.
output_layer = Dense(len(np.unique(labels_full_data)), activation='softmax')(x)


In [None]:
# Leave-one-subject-out evaluation: each fold holds out one subject's block
# of epochs for testing and trains on all remaining epochs.
X = meg_full_data.copy()
y = labels_full_data.copy()

# Fold layout: the stacked data is sliced into consecutive, equally sized
# blocks, one per subject.
N_SUBJECTS = 24
EPOCHS_PER_SUBJECT = 1886

# Fixed-size slicing is only a valid per-subject split when every subject
# contributed exactly EPOCHS_PER_SUBJECT epochs. Fail loudly otherwise,
# instead of silently mixing subjects across train and test.
expected_epochs = N_SUBJECTS * EPOCHS_PER_SUBJECT
if X.shape[0] != expected_epochs or y.shape[0] != expected_epochs:
    raise ValueError(
        "Expected {} subjects x {} epochs = {} epochs, but X has {} and y has {}. "
        "Derive the fold boundaries from the actual per-subject epoch counts."
        .format(N_SUBJECTS, EPOCHS_PER_SUBJECT, expected_epochs, X.shape[0], y.shape[0]))

# Architecture template only; it is never trained itself.
template_model = Model(inputs=input_layer, outputs=output_layer)

losses = []
metrics = []

for subj in range(N_SUBJECTS):
    # clone_model creates new layers with newly initialised weights.
    # Rebuilding Model(inputs=input_layer, outputs=output_layer) each fold
    # would reuse the same layer objects, so weights fitted in earlier folds
    # (which trained on this fold's test subject) would carry over and leak
    # test data into the evaluation.
    model = tf.keras.models.clone_model(template_model)
    # Compile the model
    model.compile(optimizer=Adam(learning_rate=0.001), loss="sparse_categorical_crossentropy", metrics=['accuracy'])
    print("Treating subject ", subj," as the test subj:")
    start = subj * EPOCHS_PER_SUBJECT
    end = start + EPOCHS_PER_SUBJECT
    X_test = X[start:end]
    y_test = y[start:end]
    # Everything outside the held-out block is training data.
    X_train = np.delete(X, slice(start, end), axis=0)
    y_train = np.delete(y, slice(start, end), axis=0)
    print(X_train.shape, y_train.shape)
    print(X_test.shape, y_test.shape)
    model.fit(X_train, y_train, batch_size=64, epochs=30, verbose=True)
    loss, accuracy = model.evaluate(X_test, y_test, verbose=1)
    print(f'Test Accuracy: {accuracy:.4f}')
    losses.append(loss)
    metrics.append(accuracy)
    

In [None]:
# Summarise the per-fold results as mean +/- standard deviation.
loss_mean = np.mean(losses)
loss_std = np.std(losses)
metric_mean = np.mean(metrics)
metric_std = np.std(metrics)

print("""\n Test Performance: 
              Loss: {:.4f} +/- {:.4f}.
              Metric: {:.4f} +/- {:.4f}"""
      .format(loss_mean, loss_std, metric_mean, metric_std))