In [None]:
from __future__ import print_function
import glob
import os
import time

import matplotlib.pyplot as plt
import numpy as np
from six.moves import cPickle as pickle
%matplotlib inline

import tensorflow as tf
slim = tf.contrib.slim
from tensorflow.python.client import timeline

print("import done")

## EEGNET implementation

The passage of the WaveNet paper (https://arxiv.org/pdf/1609.03499.pdf) most relevant to classification:
"As a last experiment we looked at speech recognition with WaveNets on the TIMIT (Garofolo et al., 1993) dataset. For this task we added a mean-pooling layer after the dilation convolutions that aggregated the activations to coarser frames spanning 10 milliseconds (160 x downsampling). The pooling layer was followed by a few non-causal convolutions. We trained WaveNet with two loss terms, one to predict the next sample and one to classify the frame, the model generalized better than with a single loss and achieved 18.8 PER on the test set, which is to our knowledge the best score obtained from a model trained directly on raw audio on TIMIT."

See also (raw multichannel waveform input): http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43290.pdf
"Input: This layer extracts 275 ms waveform segments from each of M input microphones. Successive inputs are hopped by 10ms. At the 16kHz sampling rate used in our experiments each segment contains M × 4401 dimensions."
....

In [None]:
_FILE_NUM_POINTS = 240000
_CHANNELS = 16
_NUM_LABELS = 2
_NUM_SPLITS = 100
_SIGMA_THRESHOLD = 1.0
_BATCH_SIZE = 16
_BATCH_NUM_POINTS = _FILE_NUM_POINTS/_NUM_SPLITS
print(_BATCH_NUM_POINTS)

def preprocess_dataset(data, label):
    """Split one recording into segments, drop dropout segments and normalize.

    Args:
      data: flat float tensor of one recording (_FILE_NUM_POINTS * _CHANNELS
        values); it is reshaped to [_FILE_NUM_POINTS, _CHANNELS]. Assumes the
        samples are stored time-major with interleaved channels -- TODO confirm
        against the TFRecord writer.
      label: scalar integer class label of the recording.

    Returns:
      (data, label): data is [num_segments, 1, _BATCH_NUM_POINTS, _CHANNELS]
      float, label is a one-hot int32 tensor [num_segments, _NUM_LABELS] with
      the recording's label repeated for every kept segment.
    """
    # Split into _NUM_SPLITS equal segments (shorter inputs speed up training):
    # [_NUM_SPLITS, _FILE_NUM_POINTS // _NUM_SPLITS, _CHANNELS].
    data = tf.reshape(data, shape=[_FILE_NUM_POINTS, _CHANNELS])
    data = tf.pack(tf.split(0, _NUM_SPLITS, data), axis=0)
    # Remove dropout segments: keep those whose variance over time and channels
    # exceeds _SIGMA_THRESHOLD**2, i.e. whose standard deviation exceeds
    # _SIGMA_THRESHOLD.
    _, var = tf.nn.moments(data, axes=[1, 2])
    # tf.where returns [num_true, 1] indices; flatten to 1-D for tf.gather.
    idx_clean = tf.reshape(tf.where(tf.greater(var, _SIGMA_THRESHOLD ** 2)), shape=[-1])
    data = tf.gather(data, idx_clean)
    # One one-hot label row per kept segment.
    label = tf.one_hot(label, _NUM_LABELS, dtype=tf.int32)
    num_segments = tf.shape(data)[0]
    label = tf.reshape(tf.tile(label, [num_segments]), shape=[num_segments, _NUM_LABELS])
    # Per segment and channel: subtract the mean over time, then divide by
    # 4 * max|x| so every value lies in [-0.25, 0.25].
    data_mean = tf.expand_dims(tf.reduce_mean(data, reduction_indices=[1]), dim=1)
    data = tf.sub(data, data_mean)
    data_max = tf.expand_dims(tf.reduce_max(tf.abs(data), reduction_indices=[1]), dim=1)
    # The variance test above pools all channels, so a kept segment can still
    # contain a flat channel (max|x| == 0) -> 0/0 = NaN. Floor the divisor.
    data_max = tf.maximum(data_max, 1e-8)
    data = tf.div(data, tf.mul(4.0, data_max))
    # 4D tensor with height = 1: [batch, height, width, channels]
    data = tf.expand_dims(data, dim=1)
    return data, label
    

def read_dataset(folder):
    """Build a shuffled, batched TFRecord input pipeline.

    Args:
      folder: glob pattern matching the TFRecord files to read
        (e.g. './dataset_small/*.tfr').

    Returns:
      (data, label) batch tensors from tf.train.shuffle_batch, holding
      _BATCH_SIZE segments as produced by preprocess_dataset and their one-hot
      labels.

    Raises:
      ValueError: if `folder` matches no files.
    """
    filenames = glob.glob(folder)
    if not filenames:
        # Fail fast with the offending pattern rather than leaving it to the
        # slim input pipeline.
        raise ValueError("No TFRecord files match %r" % (folder,))
    print("Loading #%d files."%len(filenames))

    reader = tf.TFRecordReader

    # Each record holds one flattened recording and its integer label.
    keys_to_features = {
        'data': tf.FixedLenFeature([_FILE_NUM_POINTS*_CHANNELS], tf.float32),
        'label': tf.FixedLenFeature([], tf.int64),
        #'filename': tf.FixedLenFeature([], tf.string),
    }
    items_to_handlers = {
        'data': slim.tfexample_decoder.Tensor('data'), 
        'label': slim.tfexample_decoder.Tensor('label'), 
        #'filename': slim.tfexample_decoder.Tensor('filename'), 
    }    
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)

    items_to_descriptions = {
        'data': '240000 sample points of iEEG.',
        'label': 'Label 0 indicates interictal and 1 preictal.', 
        #'filename': 'File name containing the data',
    }

    # NOTE(review): num_samples is dataset metadata; reading here is unbounded
    # (num_epochs=None below). The value 1 looks like a placeholder --
    # presumably it should be the total number of examples. TODO confirm.
    dataset = slim.dataset.Dataset(
        data_sources=filenames, 
        reader=reader, 
        decoder=decoder, 
        num_samples=1, 
        items_to_descriptions=items_to_descriptions)

    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, shuffle=True, num_epochs=None, common_queue_capacity=16, common_queue_min=1)

    data, label = data_provider.get(['data', 'label'])

    ## Preprocess
    data, label = preprocess_dataset(data, label)

    ## Batch it up.
    # enqueue_many=True: each recording yields a variable number of segments,
    # and every segment becomes its own queue element.
    data, label = tf.train.shuffle_batch([data, label], 
                                         batch_size=_BATCH_SIZE, 
                                         capacity=40*_NUM_SPLITS, 
                                         min_after_dequeue=20*_NUM_SPLITS, 
                                         num_threads=1, 
                                         enqueue_many=True)
    return data, label

In [None]:
# Number of input channels (one per iEEG electrode). Derived from the data
# configuration so the model and the input pipeline cannot drift apart.
input_channels = _CHANNELS
# Feature maps in the residual stream and in each skip projection.
residual_channels = 2*input_channels
# Target temporal width after the skip-path average pooling; network() uses
# a pool stride of _BATCH_NUM_POINTS // pool_size.
pool_size = 2400
# Width of the [1, filter_width] (dilated) convolution kernels.
filter_width = 3

def _gated_residual_unit(layer_input, dilation):
    """One WaveNet-style gated dilated convolution plus 1x1 projection.

    Must be called inside the caller's variable scope and slim arg_scope: it
    creates the 'dilconv' and '1x1skip' variables in the current scope.

    Args:
      layer_input: 4D activation tensor with residual_channels feature maps.
      dilation: dilation rate of the [1, filter_width] convolution.

    Returns:
      The 1x1-projected activation (residual_channels maps). It serves both as
      this layer's skip contribution and as its residual branch.
    """
    hidden = slim.conv2d(layer_input, 2*residual_channels, [1, filter_width], stride=1, rate=dilation,
                         activation_fn=None, scope='dilconv')
    # Split the features in half: tanh "filter" branch and sigmoid "gate" branch.
    filtr, gate = tf.split(3, 2, hidden)
    hidden = tf.mul(tf.tanh(filtr), tf.sigmoid(gate), name='filterXgate')
    return slim.conv2d(hidden, residual_channels, 1, activation_fn=None, scope='1x1skip')


def network(batch_data, reuse=False, is_training=True):
    """Build the EEGNet classifier and return its logits.

    Structure: input conv, then gated dilated conv layers with residual and
    skip connections, then a pooled and strided conv head on the summed
    skips, then dropout and a linear layer.

    Args:
      batch_data: 4D tensor [batch, 1, width, channels], e.g. the segments
        produced by preprocess_dataset.
      reuse: reuse the variables of an existing 'eegnet_network' scope.
      is_training: forwarded to slim.batch_norm and slim.dropout.

    Returns:
      Unnormalized logits of shape [batch, _NUM_LABELS].
    """
    # Dilation rate of each gated layer ('layer1', 'layer2', ...).
    dilations = (2, 4, 8)
    # Average-pool geometry. int() keeps the sizes integral even if
    # _BATCH_NUM_POINTS was computed with true division (float on Python 3).
    pool_kernel = int(_BATCH_NUM_POINTS) * 2 // pool_size
    pool_stride = int(_BATCH_NUM_POINTS) // pool_size

    with tf.variable_scope('eegnet_network', reuse=reuse):
        with slim.arg_scope([slim.batch_norm], 
                            is_training=is_training):
            with slim.arg_scope([slim.conv2d, slim.fully_connected], 
                                weights_initializer=slim.xavier_initializer(), 
                                normalizer_fn=slim.batch_norm):
                with tf.variable_scope('input_layer'):
                    hidden = slim.conv2d(batch_data, residual_channels, [1, filter_width], stride=1, rate=1, 
                                         activation_fn=None, scope='conv1')

                with tf.variable_scope('hidden'):
                    skip = None
                    for index, dilation in enumerate(dilations, start=1):
                        with tf.variable_scope('layer%d' % index):
                            projected = _gated_residual_unit(hidden, dilation)
                            # Skip connections: sum of every layer's projection.
                            skip = projected if skip is None else tf.add(skip, projected)
                            # Residual connection. It is skipped after the last
                            # layer, whose residual output would be unused.
                            if index < len(dilations):
                                hidden = tf.add(projected, hidden)

                with tf.variable_scope('skip_processing'):
                    hidden = tf.nn.relu(skip)
                    # VALID pooling: width = (_BATCH_NUM_POINTS - pool_kernel) // pool_stride + 1.
                    hidden = slim.avg_pool2d(hidden, [1, pool_kernel], [1, pool_stride])
                    # Shapes below assume the defaults (2400 points, pool kernel 2, stride 1).
                    # 1 x 2399 x residual_channels
                    hidden = slim.conv2d(hidden, 32, 1, activation_fn=tf.nn.relu, scope='1x1compress1')
                    # The '1x5reduce*' scope names are kept so existing checkpoints
                    # still load; the kernels are actually 1x8, 1x8 and 1x6.
                    hidden = slim.conv2d(hidden, 16, [1, 8], stride=4, activation_fn=tf.nn.relu, scope='1x5reduce1')
                    # 1 x 600 x 16 (SAME padding: ceil(2399 / 4))
                    hidden = slim.conv2d(hidden, 8, 1, activation_fn=tf.nn.relu, scope='1x1compress2')
                    hidden = slim.conv2d(hidden, 4, [1, 8], stride=4, activation_fn=tf.nn.relu, scope='1x5reduce2')
                    # 1 x 150 x 4
                    hidden = slim.conv2d(hidden, 2, 1, activation_fn=tf.nn.relu, scope='1x1compress3')
                    hidden = slim.conv2d(hidden, 2, [1, 6], stride=3, activation_fn=tf.nn.relu, scope='1x5reduce3')
                    # 1 x 50 x 2

                with tf.variable_scope('logits'):
                    # keep_prob = 0.7 while training.
                    hidden = slim.dropout(hidden, 0.7, is_training=is_training)
                    hidden = slim.flatten(hidden)
                    logits = slim.fully_connected(hidden, _NUM_LABELS, activation_fn=None, 
                                                  normalizer_fn=None, scope='fc1')
    return logits

In [None]:
# Construct the computation graph: input pipeline, model, loss and optimizer,
# training metrics, summaries, checkpoint saver and tracing options.
graph = tf.Graph()

with graph.as_default():
    # Input pipeline
    train_data, train_labels = read_dataset('./dataset_small/*.tfr')

    with tf.name_scope('eegnet_handling'):
        logits = network(train_data)
        loss = slim.losses.softmax_cross_entropy(logits, train_labels, scope='loss')
        tf.scalar_summary('loss', loss)
        # slim.batch_norm (the normalizer_fn in network) registers its moving
        # mean/variance updates in the UPDATE_OPS collection instead of running
        # them itself. Make the training step depend on them; otherwise the
        # statistics used with is_training=False are never updated.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.AdamOptimizer(
                learning_rate=1e-3, epsilon=1e-4).minimize(loss, var_list=tf.trainable_variables())
        train_probabilities = tf.nn.softmax(logits)
        train_predictions = tf.one_hot(tf.argmax(train_probabilities, 1), _NUM_LABELS, dtype=tf.int32)
        # Accuracy in percent. The third positional argument of
        # slim.metrics.accuracy is `weights`, not a scale factor, so scale
        # explicitly. Compare class indices so the metric counts examples
        # rather than one-hot elements.
        train_accuracy = 100.0 * slim.metrics.accuracy(
            tf.argmax(train_predictions, 1), tf.argmax(train_labels, 1))

    init_op = tf.group(tf.initialize_all_variables(), 
                       tf.initialize_local_variables())
    
    # Add histograms for trainable variables.
    for variable in tf.trainable_variables():
        tf.histogram_summary(variable.op.name, variable)
        
    # Add summaries for activations: NOT WORKING YET. TF ERROR.
    #slim.summarize_activations()
    
    # Merge all summaries and write them to a folder.
    merged_summs = tf.merge_all_summaries()
    results_writer = tf.train.SummaryWriter('./results', graph)
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Options and metadata for capturing a FULL_TRACE timeline of one step.
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()    
    
print('computational graph created')

In [None]:
# Train the graph built above, log progress and save a checkpoint.
num_steps = 5001
# Log every 50 steps; with num_steps = 5001 the final step (5000) is logged too.
log_every = 50
# Step to trace with FULL_TRACE; its Chrome timeline is written to trace_path
# after training. None disables tracing.
trace_step = None

trace_path = './tracing/timeline.json'
save_path = './checkpoints/model.ckpt'


def _ensure_parent_dir(path):
    """Create the directory that will contain `path` if it does not exist."""
    parent = os.path.dirname(path)
    if parent and not os.path.isdir(parent):
        os.makedirs(parent)


best_loss = float('inf')

with tf.Session(graph=graph) as sess:
    ttotal = time.time()
    # Initialize variables before the queue runners start feeding the pipeline.
    sess.run(init_op)
    print('Initialized')
    with slim.queues.QueueRunners(sess):
        for step in range(num_steps):
            # Only the traced step pays for FULL_TRACE collection.
            run_kwargs = {}
            if step == trace_step:
                run_kwargs = {'options': run_options, 'run_metadata': run_metadata}
            t = time.time()
            _, l, trlabels, trprob, traccu, summary = sess.run(
                [optimizer, loss, train_labels, train_probabilities, train_accuracy, merged_summs],
                **run_kwargs)
            results_writer.add_summary(summary, step)
            elapt = time.time()
            # Track the best loss over every step, not only the logged ones.
            best_loss = min(best_loss, l)
            if step % log_every == 0:
                print('Minibatch total loss at step %d: %f' % (step, l), '| Best:', best_loss)
                print('Minibatch accuracy:', traccu)
                print('Predictions | Labels:\n', np.concatenate((trprob[:2], trlabels[:2]), axis=1))
                print('Last iter time:', elapt-t)
            # TODO: report validation accuracy here once a validation split is
            # wired into the graph.

    ettotal = time.time()
    print('Total time: %f hours' %((ettotal-ttotal)/3600.0))
    # Make sure every buffered summary reaches disk.
    results_writer.flush()

    # Save the traced step's timeline (viewable in chrome://tracing).
    if trace_step is not None and trace_step < num_steps:
        _ensure_parent_dir(trace_path)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(trace_path, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format(show_memory=True))

    # Save the variables to disk.
    _ensure_parent_dir(save_path)
    saver.save(sess, save_path)
    print("Model saved in file: %s" % save_path)

    print('Finished training')