In [None]:
# Notebook setup: imports plus logging configuration for MNE and TensorFlow.
# NOTE(review): numpy, glob and mne.datasets.multimodal are not referenced in the visible
# cells of this notebook — consider removing them if nothing else relies on them.
import numpy as np

import mne
# Only CRITICAL-level messages from MNE are printed.
mne.set_log_level(verbose='CRITICAL')
from mne.datasets import multimodal

import os
import glob
# Must be set BEFORE `import tensorflow` below, because the TF C++ runtime reads it when
# the library loads; '3' suppresses its INFO/WARNING/ERROR log lines.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
# Quiet the Python-side TF logger and AutoGraph conversion messages as well.
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)

import mneflow
# Record which mneflow version produced the results below (for reproducibility).
print(mneflow.__version__)

In [None]:
# Project root. Defaults to the original cluster scratch location; set the
# MULTIPLE_AFFIX_ROOT environment variable to run the notebook elsewhere.
root = os.environ.get('MULTIPLE_AFFIX_ROOT', '/scratch/alr664/multiple_affix')
meg = os.path.join(root, 'meg')    # holds one folder per subject: meg/<subject>/
logs = os.path.join(root, 'logs')

# Subject IDs to load; the loading cell iterates this list in order.
full_dataset = ["A0394", "A0421", "A0446", "A0451", "A0468", "A0484", "A0495", "A0502", "A0503", "A0508",
                "A0509", "A0512", "A0513", "A0514", "A0516", "A0517", "A0518", "A0519", "A0520", "A0521",
                "A0522", "A0523", "A0524", "A0525"]

In [None]:
# Subject folders present on disk, skipping hidden entries (names starting with '.').
# Sorted because os.listdir returns entries in arbitrary order.
subjects = sorted(subj for subj in os.listdir(meg) if not subj.startswith('.'))
# Subjects the loading cell expects (full_dataset) that have no folder under `meg`;
# reading their epoch files would fail.
missing_subjects = sorted(set(full_dataset) - set(subjects))
# A cell only displays its last expression, so show both values together.
len(subjects), missing_subjects

In [None]:
# Original condition labels in the epoch files, merged into one event per suffix count.
# new event name -> (new event code, original condition names to merge)
SUFFIX_CONDITIONS = {
    '0Suff': (100, ['0Suff NW', '0Suff w/o Lat.', '0Suff w/ Lat']),
    '1Suff': (101, ['1Suff PseudoStemNW', '1Suff RealStemNW', '1Suff w/ Lat.', '1Suff w/o Lat.']),
    '2Suff': (102, ['2Suff RealStemNW', '2Suff PseudoStemNW', '2Suff w/ Lat.', '2Suff w/o Lat.', '2Suff Composite']),
}
TARGET_SFREQ = 125  # Hz


def load_subject_epochs(subject, sfreq=TARGET_SFREQ):
    """Load one subject's cleaned epochs and prepare them for decoding.

    Reads ``<meg>/<subject>/<subject>_rejection-epo.fif``, resamples it to ``sfreq`` Hz,
    merges the condition labels as listed in ``SUFFIX_CONDITIONS``, and keeps only
    magnetometer channels.

    Parameters
    ----------
    subject : str
        Subject ID (e.g. ``"A0394"``); also the folder name and file-name prefix.
    sfreq : float
        Target sampling rate in Hz passed to ``Epochs.resample``.

    Returns
    -------
    The processed MNE epochs object (magnetometers only).
    """
    epochs_path = os.path.join(meg, subject, subject + '_rejection-epo.fif')
    print(epochs_path)
    epochs = mne.read_epochs(epochs_path).resample(sfreq)
    for new_name, (new_code, old_names) in SUFFIX_CONDITIONS.items():
        epochs = mne.epochs.combine_event_ids(epochs, old_names, {new_name: new_code}, copy=True)
    return epochs.pick_types(meg='mag')


epochs_list = [load_subject_epochs(subject) for subject in full_dataset]

In [None]:
from collections import Counter

# Crop every subject's epochs to the 0–0.6 s window, then report what remains.
# Epochs.crop modifies the object in place, so the epochs stored in epochs_list are
# cropped; printing after the crop means the summary matches the data sent to training.
for subj, epochs in enumerate(epochs_list, start=1):
    epochs.crop(tmin=0., tmax=0.6)
    print("Subject: ", subj)
    print(epochs.get_data().shape)
    # .tolist() so event codes print as plain ints rather than numpy scalars.
    print(Counter(epochs.events[:, 2].tolist()))

In [None]:
path = './lfcnn-3/'
data_id = 'meg_epochs_loso'

# Options forwarded to mneflow.produce_tfrecords via **import_opt.
# The epochs were reduced to magnetometers in the loading cell (pick_types(meg='mag')),
# so the channel pick must select 'mag'; 'grad' would match no channels in these epochs.
# NOTE(review): n_folds may be ignored when test_set='loso' — confirm in the mneflow docs.
import_opt = dict(path=path,
                  data_id=data_id,
                  input_type='trials',
                  target_type='int',
                  n_folds=5,
                  test_set='loso',
                  overwrite=True,
                  picks={'meg': 'mag'},
                  scale=False,
                  crop_baseline=False,
                  decimate=None,
                  )

meta = mneflow.produce_tfrecords(epochs_list, **import_opt)

In [None]:
# Training batch size passed to mneflow.Dataset.
# NOTE(review): `dataset` is not referenced by the later cells; the model below is built from `meta`.
TRAIN_BATCH_SIZE = 64
dataset = mneflow.Dataset(meta, train_batch=TRAIN_BATCH_SIZE)

In [None]:
# Fix TensorFlow's global seed so random ops (weight initialisation, dropout, ...) are
# repeatable across Restart & Run All. GPU kernels may still be non-deterministic.
SEED = 42
tf.random.set_seed(SEED)

# LFCNN hyperparameters (see the mneflow LFCNN documentation for exact semantics).
lfcnn_params = dict(n_latent=32,          # number of latent components
                    filter_length=7,      # temporal filter length, in samples
                    nonlin=tf.nn.relu,
                    padding='SAME',
                    pooling=2,            # pooling factor
                    pool_type='max',
                    dropout=.5,
                    l1_scope=["weights"],  # L1 penalty applied to variables matching "weights"
                    l1=3e-4)

meta.update(model_specs=lfcnn_params)

model = mneflow.models.LFCNN(meta)
model.build()

In [None]:
# Leave-one-subject-out training (mode='loso'; matches test_set='loso' used when the
# TFRecords were produced).
# NOTE(review): confirm the exact meaning of eval_step in the mneflow train() docs.
N_TRAIN_EPOCHS = 20
EVAL_STEP = 50
model.train(n_epochs=N_TRAIN_EPOCHS, eval_step=EVAL_STEP, mode='loso')

In [None]:
# Evaluate the trained model on the held-out TFRecords recorded in `meta`.
# NOTE(review): with test_set='loso', confirm which subject(s) `test_paths` contains.
test_paths = meta.data['test_paths']
test_loss, test_acc = model.evaluate(test_paths)
print(f"Test set: Loss = {test_loss:.4f} Accuracy = {test_acc:.4f}")