In [1]:
import numpy as np

import mne
mne.set_log_level(verbose='CRITICAL')
from mne.datasets import multimodal

import os
import glob
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)

import mneflow
print(mneflow.__version__)

0.5.6


In [2]:
root = '/scratch/alr664/multiple_affix'
meg = root + '/meg'
logs = root + '/logs'

full_dataset = ["A0394", "A0421", "A0446", "A0451", "A0468", "A0484", "A0495", "A0502", "A0503", "A0508", 
                "A0509", "A0512", "A0513", "A0514", "A0516", "A0517", "A0518", "A0519", "A0520", "A0521", 
                "A0522", "A0523", "A0524", "A0525"]

In [3]:
subjects = [subj for subj in os.listdir(meg) if not subj.startswith('.')]
subjects
len(subjects)

24

In [4]:
epochs_list = []

for subject in full_dataset:
    subj_epoch_path = meg + '/' + subject + '/' + subject + '_rejection-epo.fif'
    print(subj_epoch_path)
    subj_epoch = mne.read_epochs(subj_epoch_path)
    subj_epoch = subj_epoch.resample(125)
    subj_epoch_mag = subj_epoch.pick_types(meg='mag')
    epochs_list.append(subj_epoch_mag)

print("Done!")

epochs = mne.concatenate_epochs(epochs_list, on_mismatch='ignore')

print("Done!.")

/scratch/alr664/multiple_affix/meg/A0394/A0394_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0421/A0421_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0446/A0446_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0451/A0451_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0468/A0468_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0484/A0484_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0495/A0495_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0502/A0502_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0503/A0503_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0508/A0508_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0509/A0509_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0512/A0512_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0513/A0513_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0514/A0514_rejection-epo.fif
/scratch/alr664/multiple_affix/meg/A0516/A0516_rejection-epo.fif
/scratch/alr664/multiple_

In [5]:
epochs.get_data().shape

(45264, 207, 100)

In [6]:
epochs.event_id

{'0Suff w/ Lat': 1,
 '0Suff w/o Lat.': 2,
 '0Suff NW': 4,
 '1Suff w/ Lat.': 11,
 '1Suff w/o Lat.': 12,
 '1Suff PseudoStemNW': 14,
 '1Suff RealStemNW': 15,
 '2Suff w/ Lat.': 21,
 '2Suff w/o Lat.': 22,
 '2Suff Composite': 23,
 '2Suff PseudoStemNW': 24,
 '2Suff RealStemNW': 25}

In [7]:
from collections import Counter
Counter([event[2] for event in epochs.events])

Counter({22: 7200,
         4: 6168,
         12: 6072,
         2: 5448,
         24: 4992,
         25: 4656,
         14: 3648,
         15: 3168,
         23: 1536,
         21: 912,
         11: 744,
         1: 720})

In [8]:
epochs.times

array([-0.2  , -0.192, -0.184, -0.176, -0.168, -0.16 , -0.152, -0.144,
       -0.136, -0.128, -0.12 , -0.112, -0.104, -0.096, -0.088, -0.08 ,
       -0.072, -0.064, -0.056, -0.048, -0.04 , -0.032, -0.024, -0.016,
       -0.008,  0.   ,  0.008,  0.016,  0.024,  0.032,  0.04 ,  0.048,
        0.056,  0.064,  0.072,  0.08 ,  0.088,  0.096,  0.104,  0.112,
        0.12 ,  0.128,  0.136,  0.144,  0.152,  0.16 ,  0.168,  0.176,
        0.184,  0.192,  0.2  ,  0.208,  0.216,  0.224,  0.232,  0.24 ,
        0.248,  0.256,  0.264,  0.272,  0.28 ,  0.288,  0.296,  0.304,
        0.312,  0.32 ,  0.328,  0.336,  0.344,  0.352,  0.36 ,  0.368,
        0.376,  0.384,  0.392,  0.4  ,  0.408,  0.416,  0.424,  0.432,
        0.44 ,  0.448,  0.456,  0.464,  0.472,  0.48 ,  0.488,  0.496,
        0.504,  0.512,  0.52 ,  0.528,  0.536,  0.544,  0.552,  0.56 ,
        0.568,  0.576,  0.584,  0.592])

In [9]:
epochs = epochs.crop(tmin=0. , tmax= 0.6)

In [10]:
epochs.times

array([0.   , 0.008, 0.016, 0.024, 0.032, 0.04 , 0.048, 0.056, 0.064,
       0.072, 0.08 , 0.088, 0.096, 0.104, 0.112, 0.12 , 0.128, 0.136,
       0.144, 0.152, 0.16 , 0.168, 0.176, 0.184, 0.192, 0.2  , 0.208,
       0.216, 0.224, 0.232, 0.24 , 0.248, 0.256, 0.264, 0.272, 0.28 ,
       0.288, 0.296, 0.304, 0.312, 0.32 , 0.328, 0.336, 0.344, 0.352,
       0.36 , 0.368, 0.376, 0.384, 0.392, 0.4  , 0.408, 0.416, 0.424,
       0.432, 0.44 , 0.448, 0.456, 0.464, 0.472, 0.48 , 0.488, 0.496,
       0.504, 0.512, 0.52 , 0.528, 0.536, 0.544, 0.552, 0.56 , 0.568,
       0.576, 0.584, 0.592])

In [11]:
epochs.get_data().shape

(45264, 207, 75)

In [12]:
path = './data/'
data_id = 'meg_epochs_12'

import_opt = dict(path=path,
                  data_id=data_id,
                  input_type='trials',
                  target_type='int',
                  n_folds= 5,
                  test_set = 'holdout',
                  overwrite=True,
                  picks={'meg':'grad'},
                  scale=False,
                  crop_baseline=False,
                  decimate=None,
                  )

In [13]:
meta = mneflow.produce_tfrecords(epochs, **import_opt)

processing epochs
Input shapes: X (n, ch, t) :  (45264, 207, 75) y (n, [signal_channels], y_shape) :  (45264, 1) 
 input_type :  trials target_type :  int segment_y :  False
Preprocessing:
n: 45264
Splitting into: 6 folds x 7544
Preprocessed: (45264, 1, 75, 207) (45264, 1) folds: 6 x 7544
Preprocessed targets:  (45264, 1)
Prepocessed sample shape: (1, 75, 207)
Target shape actual/metadata:  (12,) (12,)
Saving TFRecord# 0
Updating: meta.data
Updating: meta.preprocessing


In [14]:
dataset = mneflow.Dataset(meta, train_batch=64)

Updating: meta.data


In [15]:
varcnn_params = dict(n_latent=32,
                  filter_length=7,
                  nonlin = tf.nn.relu,
                  padding = 'SAME',
                  pooling = 2,
                  pool_type='max',
                  dropout = .5,
                  l1_scope = ["weights"],
                  l1=3e-4)

meta.update(model_specs=varcnn_params)

Updating: meta.model_specs


In [16]:
model = mneflow.models.VARCNN(meta)
model.build()

Updating: meta.data
Setting reg for dmx, to l1
Built: dmx input: (None, 1, 75, 207)
input_shape: (None, 1, 75, 32)
Setting reg for tconv, to l1
Built: tconv input: (None, 1, 75, 32)
Setting reg for fc, to l1
Built: fc input: (None, 1, 38, 32)
Input shape: (1, 75, 207)
y_pred: (None, 12)
Initialization complete!


In [17]:
model.train(n_epochs=20, eval_step=50, mode='cv')

Updating: meta.train_params
Class weights:  None
Running cross-validation with 5 folds
fold: 0
Epoch 1/20
50/50 - 5s - 100ms/step - cat_ACC: 0.1375 - loss: 2.8269 - val_cat_ACC: 0.1348 - val_loss: 2.7311
Epoch 2/20
50/50 - 2s - 34ms/step - cat_ACC: 0.1528 - loss: 2.7315 - val_cat_ACC: 0.1655 - val_loss: 2.6920
Epoch 3/20
50/50 - 2s - 33ms/step - cat_ACC: 0.1516 - loss: 2.6837 - val_cat_ACC: 0.1655 - val_loss: 2.6521
Epoch 4/20
50/50 - 2s - 35ms/step - cat_ACC: 0.1391 - loss: 2.6646 - val_cat_ACC: 0.1353 - val_loss: 2.6291
Epoch 5/20
50/50 - 2s - 34ms/step - cat_ACC: 0.1353 - loss: 2.6282 - val_cat_ACC: 0.1655 - val_loss: 2.5969
Epoch 6/20
50/50 - 2s - 31ms/step - cat_ACC: 0.1572 - loss: 2.6078 - val_cat_ACC: 0.1655 - val_loss: 2.5738
Epoch 7/20
50/50 - 2s - 30ms/step - cat_ACC: 0.1388 - loss: 2.5785 - val_cat_ACC: 0.1655 - val_loss: 2.5518
Epoch 8/20
50/50 - 2s - 32ms/step - cat_ACC: 0.1484 - loss: 2.5579 - val_cat_ACC: 0.1655 - val_loss: 2.5378
Epoch 9/20
50/50 - 2s - 30ms/step - cat_

In [19]:
test_loss, test_acc = model.evaluate(meta.data['test_paths'])
print("Test set: Loss = {:.4f} Accuracy = {:.4f}".format(test_loss, test_acc))

Test set: Loss = 2.4023 Accuracy = 0.1351
