In [1]:
import tqdm
import numpy as np
import multiprocessing
import uuid
import os
import threading
import time
import random
import math

In [2]:
from box_client import get_box_client
from annotation_processor import *

In [3]:
import tensorflow as tf
import tensorflow.contrib.layers as layers

In [4]:
seizure_annotations = load_seizure_annotations("seizure_annotations.json")

In [5]:
train_dir = "tmp"

In [6]:
def choose_epoch_files(n_files, files, seizure_annotations, p=0.5):
    n_pos = int(n_files * p)
    n_neg = n_files - n_pos
    
    positive_files = []
    negative_files = []
    for name in files:
        if len(seizure_times_from_npz_filename(name, seizure_annotations)) > 0:
            positive_files.append(name)
        else:
            negative_files.append(name)
    
    positive_examples = list(np.random.choice(positive_files, size=n_pos))
    negative_examples = list(np.random.choice(negative_files, size=n_neg))
    
    return positive_examples + negative_examples

In [7]:
# Function to generate vector y corresponding to  binary classification of video clips
# intervals of duration = window_length annotation in seconds
# seizure array will receive dictionary of video name of seizure times

def ground_truth_label(seizure_array, window_start, window_length):

   # Check if sliding window overlaps with seizure window
   for k in seizure_array:

       # Here just hard-coded 10 sec as minimum duration of seizure
       if (window_start + window_length > k) and (window_start < k + 10):
            return 1
       # Windows after the 10 sec minimum duration of seizure and less than 120 secs after seizure start
       if (window_start >= k + 10) and (window_start < k + 120):
            return -1

   # Return 0 for non-seizure windows
   return 0

In [8]:
def epoch_positive_negative_times(epoch_examples, window_length=5, fps=29.97):
    epoch_file_names = list(epoch_examples.keys())
    pos_example_indices = []
    neg_example_indices = []

    for file_idx, file_name in enumerate(epoch_file_names):
        processed_chunk = epoch_examples[file_name]
        # video times (sec)
        vid_start_time = processed_chunk["start_time"]
        vid_length = int(processed_chunk["features"].shape[0]/fps)
        vid_end_time = vid_start_time + vid_length
        seizure_times = seizure_times_from_npz_filename(file_name, seizure_annotations)
        for i in range(vid_length - window_length):
            label = ground_truth_label(seizure_times, vid_start_time + i, window_length)
            if label == 0:
                neg_example_indices.append((file_idx, i))
            elif label == 1:
                pos_example_indices.append((file_idx, i))
                
    return epoch_file_names, pos_example_indices, neg_example_indices

In [9]:
def indices_to_example(epoch_examples, epoch_filenames, pair, window_length = 5, fps = 30):
    file_index = pair[0]
    time_index = pair[1]
    features = epoch_examples[epoch_filenames[file_index]]['features']
    return features[time_index:time_index+window_length*fps]

In [10]:
def generate_minibatch(batch_size, pos_p, pos_example_indices, neg_example_indices, epoch_examples, epoch_filenames):
    num_pos_examples = int(batch_size*pos_p)
    num_neg_examples = batch_size - num_pos_examples
    
    pos_examples = random.choices(pos_example_indices, k=num_pos_examples)
    neg_examples = random.choices(neg_example_indices, k=num_neg_examples)
    
    batch_x = np.zeros((batch_size, 150, 1536))
    batch_y = np.zeros((batch_size, 1))
    
    
    for i in range(batch_size):
        if i < num_pos_examples:
            pair = pos_examples[i]
            batch_y[i] = 1
        else:
            pair = neg_examples[i - num_pos_examples]
            batch_y[i] = 0
        batch_x[i, :, :] = indices_to_example(epoch_examples, epoch_file_names, pair)
    
    return batch_x, batch_y

In [29]:
all_files = [f for f in os.listdir(train_dir) if os.path.isfile(os.path.join(train_dir, f))]

In [30]:
dev_set = set([
    '2019-01-18_6 up_73-3_MAH00011_8.npz',
    '2019-01-18_6 up_73-3_MAH00013_2.npz',
    '2019-01-18_6 up_74-1_MAH00016_0.npz',
    '2019-01-20_6 up_65-2_MAH00020_8.npz',
    '2019-01-18_6 up_68-2_MAH00014_8.npz',
    '2019-01-18_6 up_68-2_MAH00014_8.npz',
    '2019-01-20_6 up_68-2_MAH00019_6.npz',
    '2019-01-18_6 up_68-2_MAH00012_3.npz',
    '2019-01-18_6 up_70-1_MAH00015_7.npz',
    '2019-01-18_6 up_73-3_MAH00016_6.npz',
    '2019-01-18_6 up_68-2_MAH00014_0.npz',
    '2019-01-18_6 up_68-2_MAH00015_2.npz',
    '2019-01-18_6 up_68-2_MAH00011_8.npz',
    '2019-01-20_6 up_73-3_MAH00020_5.npz',
    '2019-01-18_6 up_68-2_MAH00015_2.npz',
    '2019-01-18_6 up_68-2_MAH00013_3.npz',
    '2019-01-20_6 up_65-2_MAH00021_2.npz',
    '2019-01-18_6 up_73-3_MAH00015_0.npz',
    '2019-01-20_6 up_74-1_MAH00019_8.npz',
    '2019-01-20_6 up_70-1_MAH00018_8.npz',
    '2019-01-20_6 up_74-1_MAH00018_7.npz',
    '2019-01-18_6 up_73-3_MAH00012_6.npz',
    '2019-01-20_6 up_73-2_MAH00018_1.npz',
    '2019-01-18_6 up_73-3_MAH00012_5.npz',
    '2019-01-18_6 up_70-1_MAH00014_2.npz',
    '2019-01-20_6 up_68-2_MAH00018_5.npz',
    '2019-01-18_6 up_70-1_MAH00014_9.npz',
    '2019-01-18_6 up_74-1_MAH00012_6.npz',
    '2019-01-18_6 up_73-2_MAH00011_5.npz',
    '2019-01-18_6 up_70-1_MAH00011_6.npz',
    '2019-01-20_6 up_68-2_MAH00020_5.npz',
    '2019-01-18_6 up_73-3_MAH00015_9.npz',
    '2019-01-18_6 up_73-2_MAH00011_1.npz',
    '2019-01-18_6 up_70-1_MAH00011_6.npz',
    '2019-01-18_6 up_65-2_MAH00014_6.npz',
    '2019-01-20_6 up_68-2_MAH00020_5.npz',
    '2019-01-20_6 up_70-1_MAH00018_6.npz',
    '2019-01-18_6 up_65-2_MAH00011_6.npz',
    '2019-01-18_6 up_73-2_MAH00011_6.npz',
    '2019-01-18_6 up_65-2_MAH00014_8.npz'
])

In [31]:
training_set = [f for f in all_files if f not in dev_set]

In [14]:
# Set up the network
RNN_HIDDEN = 256
INPUT_SIZE = 1536
OUTPUT_SIZE = 1 # 1 bit per timestep
LEARNING_RATE = 0.001
BATCH_SIZE = 128
FALSE_NEG_PEN = 100

In [15]:
inputs  = tf.placeholder(tf.float32, (None, None, INPUT_SIZE))  # (time, batch, in)
outputs = tf.placeholder(tf.float32, (None, OUTPUT_SIZE)) # (batch, out)

cell = tf.nn.rnn_cell.LSTMCell(RNN_HIDDEN, state_is_tuple=True)

In [16]:
initial_state = cell.zero_state(tf.shape(inputs)[0], tf.float32)
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

In [21]:
# project output from rnn output size to OUTPUT_SIZE. Sometimes it is worth adding
# an extra layer here.
pred = tf.layers.dense(rnn_states[1], OUTPUT_SIZE, activation=tf.sigmoid)

In [38]:
error = tf.add(tf.multiply(tf.multiply(outputs, -tf.log(pred)), FALSE_NEG_PEN),
               tf.multiply(tf.subtract(1.0, outputs), -tf.log(tf.subtract(1.0, pred))))
error = tf.reduce_mean(error)
train_fn = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(error)

In [39]:
#Training Loop
n_epoch_files = 120
n_epochs = 40
pos_p = 1/2
batch_size = BATCH_SIZE
LEARNING_RATE = 1e-4


with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    for epoch in tqdm.tqdm_notebook(range(n_epochs)):
        epoch_files = choose_epoch_files(n_epoch_files, training_set, seizure_annotations)
        epoch_examples = {}
        for file_name in epoch_files:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                epoch_examples[file_name] = example
                

        epoch_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(epoch_examples)
        n_minibatches = int((len(pos_example_indices)/pos_p)/batch_size)
        n_minibatches = 300
        epoch_error = 0.0
        for batch_i in tqdm.tqdm_notebook(range(n_minibatches)):
            batch_X, batch_Y = generate_minibatch(batch_size, pos_p, 
                                                  pos_example_indices, neg_example_indices, 
                                                  epoch_examples, epoch_file_names)
            
            epoch_error += session.run([error, train_fn], {
                inputs: batch_X,
                outputs: batch_Y,
            })[0]
            
        del epoch_examples
        
        dev_examples = {}
        for file_name in dev_set:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                dev_examples[file_name] = example
        
        epoch_error /= n_minibatches        
        print("Epoch %d, train error: %.4f" % (epoch+1, epoch_error))
                
        dev_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(dev_examples)
        dev_set_error = 0.0
        n_dev_examples = len(pos_example_indices) + len(neg_example_indices)
        n_dev_batches = math.ceil(n_dev_examples/BATCH_SIZE)
        
        if (epoch+1) % 4 == 0:
            conf_matrix = np.zeros((2,2))
            for batch_i in tqdm.tqdm_notebook(range(n_dev_batches)):
                offset = batch_i * BATCH_SIZE
                size = min(BATCH_SIZE, n_dev_examples - offset) 

                batch_X = np.zeros((size, 150, 1536))
                batch_Y = np.zeros((size, 1))


                for i in range(size):
                    if i + offset < len(pos_example_indices):
                        pair = pos_example_indices[i + offset]
                        batch_Y[i] = 1
                    else:
                        pair = neg_example_indices[i + offset - len(pos_example_indices)]
                        batch_Y[i] = 0
                    batch_X[i, :, :] = indices_to_example(dev_examples, dev_file_names, pair)

                predictions, err = session.run([pred, error], {
                    inputs: batch_X,
                    outputs: batch_Y
                })
                predictions = predictions.flatten()
                batch_Y = batch_Y.flatten()

                tp = np.sum((predictions >= 0.5) & (batch_Y == 1))
                fp = np.sum((predictions >= 0.5) & (batch_Y == 0))
                tn = np.sum((predictions < 0.5) & (batch_Y == 0))
                fn = np.sum((predictions < 0.5) & (batch_Y == 1))
                dev_set_error += err
                conf_matrix[0, 0] += tn
                conf_matrix[1, 0] += fn
                conf_matrix[0, 1] += fp
                conf_matrix[1, 1] += tp
        
            dev_set_error /= n_dev_batches       
            print("Dev set error: %.4f" % (dev_set_error))
            print(conf_matrix)

HBox(children=(IntProgress(value=0, max=40), HTML(value='')))

HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 1, train error: 2.3510


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 2, train error: 1.7388


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 3, train error: 1.7973


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 4, train error: 1.5536


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 2.6546
[[9812. 8728.]
 [  57.  181.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 5, train error: 1.2844


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 6, train error: 1.2569


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 7, train error: 1.6124


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 8, train error: 1.5317


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 3.3828
[[ 2324. 16216.]
 [    0.   238.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 9, train error: 1.2153


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 10, train error: 0.8378


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 11, train error: 0.6261


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 12, train error: 1.0797


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 3.7552
[[16432.  2108.]
 [  190.    48.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 13, train error: 0.8920


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 14, train error: 0.8779


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 15, train error: 0.4653


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 16, train error: 0.8782


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 5.5174
[[17322.  1218.]
 [  215.    23.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 17, train error: 0.5151


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 18, train error: 1.6992


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 19, train error: 1.7525


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 20, train error: 1.1248


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 5.9427
[[    0. 18540.]
 [    0.   238.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 21, train error: 1.6783


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 22, train error: 1.1169


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 23, train error: 0.4466


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 24, train error: 0.8778


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 4.6005
[[1.7144e+04 1.3960e+03]
 [2.2600e+02 1.2000e+01]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 25, train error: 0.3311


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 26, train error: 0.1942


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 27, train error: 1.1594


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 28, train error: 0.2077


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 5.8671
[[1.8031e+04 5.0900e+02]
 [2.3000e+02 8.0000e+00]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 29, train error: 0.1294


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 30, train error: 0.9465


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 31, train error: 0.2952


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 32, train error: 1.4425


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 3.5925
[[ 2430. 16110.]
 [    0.   238.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 33, train error: 1.5581


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 34, train error: 0.5215


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 35, train error: 0.5972


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 36, train error: 0.3546


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 7.0663
[[1.7402e+04 1.1380e+03]
 [2.3100e+02 7.0000e+00]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 37, train error: 0.8106


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 38, train error: 1.3687


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 39, train error: 0.5217


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 40, train error: 0.4628


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 6.4976
[[1.8217e+04 3.2300e+02]
 [2.3300e+02 5.0000e+00]]


In [40]:
#Training Loop
n_epoch_files = 120
n_epochs = 40
pos_p = 1/2
batch_size = BATCH_SIZE
LEARNING_RATE = 1e-5

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    for epoch in tqdm.tqdm_notebook(range(n_epochs)):
        epoch_files = choose_epoch_files(n_epoch_files, training_set, seizure_annotations)
        epoch_examples = {}
        for file_name in epoch_files:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                epoch_examples[file_name] = example
                

        epoch_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(epoch_examples)
        n_minibatches = int((len(pos_example_indices)/pos_p)/batch_size)
        n_minibatches = 300
        epoch_error = 0.0
        for batch_i in tqdm.tqdm_notebook(range(n_minibatches)):
            batch_X, batch_Y = generate_minibatch(batch_size, pos_p, 
                                                  pos_example_indices, neg_example_indices, 
                                                  epoch_examples, epoch_file_names)
            
            epoch_error += session.run([error, train_fn], {
                inputs: batch_X,
                outputs: batch_Y,
            })[0]
            
        del epoch_examples
        
        dev_examples = {}
        for file_name in dev_set:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                dev_examples[file_name] = example
        
        epoch_error /= n_minibatches        
        print("Epoch %d, train error: %.4f" % (epoch+1, epoch_error))
                
        dev_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(dev_examples)
        dev_set_error = 0.0
        n_dev_examples = len(pos_example_indices) + len(neg_example_indices)
        n_dev_batches = math.ceil(n_dev_examples/BATCH_SIZE)
        
        if (epoch+1) % 4 == 0:
            conf_matrix = np.zeros((2,2))
            for batch_i in tqdm.tqdm_notebook(range(n_dev_batches)):
                offset = batch_i * BATCH_SIZE
                size = min(BATCH_SIZE, n_dev_examples - offset) 

                batch_X = np.zeros((size, 150, 1536))
                batch_Y = np.zeros((size, 1))


                for i in range(size):
                    if i + offset < len(pos_example_indices):
                        pair = pos_example_indices[i + offset]
                        batch_Y[i] = 1
                    else:
                        pair = neg_example_indices[i + offset - len(pos_example_indices)]
                        batch_Y[i] = 0
                    batch_X[i, :, :] = indices_to_example(dev_examples, dev_file_names, pair)

                predictions, err = session.run([pred, error], {
                    inputs: batch_X,
                    outputs: batch_Y
                })
                predictions = predictions.flatten()
                batch_Y = batch_Y.flatten()

                tp = np.sum((predictions >= 0.5) & (batch_Y == 1))
                fp = np.sum((predictions >= 0.5) & (batch_Y == 0))
                tn = np.sum((predictions < 0.5) & (batch_Y == 0))
                fn = np.sum((predictions < 0.5) & (batch_Y == 1))
                dev_set_error += err
                conf_matrix[0, 0] += tn
                conf_matrix[1, 0] += fn
                conf_matrix[0, 1] += fp
                conf_matrix[1, 1] += tp
        
            dev_set_error /= n_dev_batches       
            print("Dev set error: %.4f" % (dev_set_error))
            print(conf_matrix)

HBox(children=(IntProgress(value=0, max=40), HTML(value='')))

HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 1, train error: 2.3626


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 2, train error: 1.3715


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 3, train error: 1.6009


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 4, train error: 1.5949


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 2.3511
[[10164.  8376.]
 [   95.   143.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 5, train error: 1.3018


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 6, train error: 0.9306


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 7, train error: 1.4168


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 8, train error: 2.0393


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 2.3607
[[ 7072. 11468.]
 [   28.   210.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 9, train error: 1.3433


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 10, train error: 1.1946


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 11, train error: 0.7677


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 12, train error: 1.2902


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 2.3999
[[10786.  7754.]
 [  122.   116.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 13, train error: 1.7118


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 14, train error: 1.1031


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 15, train error: 0.8556


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 16, train error: 0.8028


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 5.8584
[[17449.  1091.]
 [  220.    18.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 17, train error: 1.6015


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 18, train error: 1.0156


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 19, train error: 0.7276


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 20, train error: 0.7105


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 9.3944
[[1.5000e+01 1.8525e+04]
 [0.0000e+00 2.3800e+02]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 21, train error: 1.9818


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 22, train error: 1.1044


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 23, train error: 0.6929


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 24, train error: 0.4160


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 6.4822
[[1.7817e+04 7.2300e+02]
 [2.3200e+02 6.0000e+00]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 25, train error: 0.7271


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 26, train error: 2.8825


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 27, train error: 1.4400


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 28, train error: 0.5973


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 3.3597
[[8780. 9760.]
 [  42.  196.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 29, train error: 0.7508


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 30, train error: 1.0399


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 31, train error: 0.4429


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 32, train error: 0.1984


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 6.3401
[[1.8067e+04 4.7300e+02]
 [2.3400e+02 4.0000e+00]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 33, train error: 0.8856


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 34, train error: 0.2871


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 35, train error: 0.2895


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 36, train error: 2.1641


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 2.7986
[[ 5721. 12819.]
 [   14.   224.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 37, train error: 1.3038


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 38, train error: 0.3127


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 39, train error: 2.4062


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 40, train error: 1.4641


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 2.3155
[[9263. 9277.]
 [  79.  159.]]


In [41]:
error = tf.add(tf.multiply(tf.multiply(outputs, -tf.log(pred)), 10.0),
               tf.multiply(tf.subtract(1.0, outputs), -tf.log(tf.subtract(1.0, pred))))
error = tf.reduce_mean(error)
train_fn = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(error)

In [42]:
#Training Loop
n_epoch_files = 120
n_epochs = 40
pos_p = 1/2
batch_size = BATCH_SIZE
LEARNING_RATE = 1e-4

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    for epoch in tqdm.tqdm_notebook(range(n_epochs)):
        epoch_files = choose_epoch_files(n_epoch_files, training_set, seizure_annotations)
        epoch_examples = {}
        for file_name in epoch_files:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                epoch_examples[file_name] = example
                

        epoch_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(epoch_examples)
        n_minibatches = int((len(pos_example_indices)/pos_p)/batch_size)
        n_minibatches = 300
        epoch_error = 0.0
        for batch_i in tqdm.tqdm_notebook(range(n_minibatches)):
            batch_X, batch_Y = generate_minibatch(batch_size, pos_p, 
                                                  pos_example_indices, neg_example_indices, 
                                                  epoch_examples, epoch_file_names)
            
            epoch_error += session.run([error, train_fn], {
                inputs: batch_X,
                outputs: batch_Y,
            })[0]
            
        del epoch_examples
        
        dev_examples = {}
        for file_name in dev_set:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                dev_examples[file_name] = example
        
        epoch_error /= n_minibatches        
        print("Epoch %d, train error: %.4f" % (epoch+1, epoch_error))
                
        dev_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(dev_examples)
        dev_set_error = 0.0
        n_dev_examples = len(pos_example_indices) + len(neg_example_indices)
        n_dev_batches = math.ceil(n_dev_examples/BATCH_SIZE)
        
        if (epoch+1) % 4 == 0:
            conf_matrix = np.zeros((2,2))
            for batch_i in tqdm.tqdm_notebook(range(n_dev_batches)):
                offset = batch_i * BATCH_SIZE
                size = min(BATCH_SIZE, n_dev_examples - offset) 

                batch_X = np.zeros((size, 150, 1536))
                batch_Y = np.zeros((size, 1))


                for i in range(size):
                    if i + offset < len(pos_example_indices):
                        pair = pos_example_indices[i + offset]
                        batch_Y[i] = 1
                    else:
                        pair = neg_example_indices[i + offset - len(pos_example_indices)]
                        batch_Y[i] = 0
                    batch_X[i, :, :] = indices_to_example(dev_examples, dev_file_names, pair)

                predictions, err = session.run([pred, error], {
                    inputs: batch_X,
                    outputs: batch_Y
                })
                predictions = predictions.flatten()
                batch_Y = batch_Y.flatten()

                tp = np.sum((predictions >= 0.5) & (batch_Y == 1))
                fp = np.sum((predictions >= 0.5) & (batch_Y == 0))
                tn = np.sum((predictions < 0.5) & (batch_Y == 0))
                fn = np.sum((predictions < 0.5) & (batch_Y == 1))
                dev_set_error += err
                conf_matrix[0, 0] += tn
                conf_matrix[1, 0] += fn
                conf_matrix[0, 1] += fp
                conf_matrix[1, 1] += tp
        
            dev_set_error /= n_dev_batches       
            print("Dev set error: %.4f" % (dev_set_error))
            print(conf_matrix)

HBox(children=(IntProgress(value=0, max=40), HTML(value='')))

HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 1, train error: 1.0733


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 2, train error: 0.6652


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 3, train error: 0.4472


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 4, train error: 0.3867


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6005
[[17041.  1499.]
 [  210.    28.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 5, train error: 0.2321


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 6, train error: 0.3070


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 7, train error: 0.5956


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 8, train error: 0.2922


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7894
[[15158.  3382.]
 [  209.    29.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 9, train error: 0.2040


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 10, train error: 0.3231


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 11, train error: 0.1830


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 12, train error: 0.1650


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6246
[[18218.   322.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 13, train error: 0.1397


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 14, train error: 0.2682


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 15, train error: 0.2019


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 16, train error: 0.1282


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7346
[[18226.   314.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 17, train error: 0.1277


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 18, train error: 0.1653


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 19, train error: 0.1427


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 20, train error: 0.1092


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7561
[[18284.   256.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 21, train error: 0.0566


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 22, train error: 0.0944


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 23, train error: 0.3648


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 24, train error: 0.1677


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6704
[[1.7708e+04 8.3200e+02]
 [2.2400e+02 1.4000e+01]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 25, train error: 0.1647


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 26, train error: 0.1467


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 27, train error: 0.0550


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 28, train error: 0.1277


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.8027
[[1.8142e+04 3.9800e+02]
 [2.3100e+02 7.0000e+00]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 29, train error: 0.0867


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 30, train error: 0.0439


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 31, train error: 0.0508


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 32, train error: 0.0303


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.8799
[[1.8119e+04 4.2100e+02]
 [2.2600e+02 1.2000e+01]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 33, train error: 0.1200


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 34, train error: 0.0487


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 35, train error: 0.0379


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 36, train error: 0.1232


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7526
[[1.7655e+04 8.8500e+02]
 [2.2700e+02 1.1000e+01]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 37, train error: 0.0713


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 38, train error: 0.1490


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 39, train error: 0.0497


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 40, train error: 0.1166


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.9182
[[1.8091e+04 4.4900e+02]
 [2.3000e+02 8.0000e+00]]


In [43]:
#Training Loop
n_epoch_files = 120
n_epochs = 40
pos_p = 1/2
batch_size = BATCH_SIZE
LEARNING_RATE = 1e-5

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    for epoch in tqdm.tqdm_notebook(range(n_epochs)):
        epoch_files = choose_epoch_files(n_epoch_files, training_set, seizure_annotations)
        epoch_examples = {}
        for file_name in epoch_files:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                epoch_examples[file_name] = example
                

        epoch_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(epoch_examples)
        n_minibatches = int((len(pos_example_indices)/pos_p)/batch_size)
        n_minibatches = 300
        epoch_error = 0.0
        for batch_i in tqdm.tqdm_notebook(range(n_minibatches)):
            batch_X, batch_Y = generate_minibatch(batch_size, pos_p, 
                                                  pos_example_indices, neg_example_indices, 
                                                  epoch_examples, epoch_file_names)
            
            epoch_error += session.run([error, train_fn], {
                inputs: batch_X,
                outputs: batch_Y,
            })[0]
            
        del epoch_examples
        
        dev_examples = {}
        for file_name in dev_set:
            with np.load(os.path.join(train_dir, file_name)) as data:
                example = {}
                example["features"] = data["features"]
                example["start_time"] = data["start_time"]
                dev_examples[file_name] = example
        
        epoch_error /= n_minibatches        
        print("Epoch %d, train error: %.4f" % (epoch+1, epoch_error))
                
        dev_file_names, pos_example_indices, neg_example_indices = epoch_positive_negative_times(dev_examples)
        dev_set_error = 0.0
        n_dev_examples = len(pos_example_indices) + len(neg_example_indices)
        n_dev_batches = math.ceil(n_dev_examples/BATCH_SIZE)
        
        if (epoch+1) % 4 == 0:
            conf_matrix = np.zeros((2,2))
            for batch_i in tqdm.tqdm_notebook(range(n_dev_batches)):
                offset = batch_i * BATCH_SIZE
                size = min(BATCH_SIZE, n_dev_examples - offset) 

                batch_X = np.zeros((size, 150, 1536))
                batch_Y = np.zeros((size, 1))


                for i in range(size):
                    if i + offset < len(pos_example_indices):
                        pair = pos_example_indices[i + offset]
                        batch_Y[i] = 1
                    else:
                        pair = neg_example_indices[i + offset - len(pos_example_indices)]
                        batch_Y[i] = 0
                    batch_X[i, :, :] = indices_to_example(dev_examples, dev_file_names, pair)

                predictions, err = session.run([pred, error], {
                    inputs: batch_X,
                    outputs: batch_Y
                })
                predictions = predictions.flatten()
                batch_Y = batch_Y.flatten()

                tp = np.sum((predictions >= 0.5) & (batch_Y == 1))
                fp = np.sum((predictions >= 0.5) & (batch_Y == 0))
                tn = np.sum((predictions < 0.5) & (batch_Y == 0))
                fn = np.sum((predictions < 0.5) & (batch_Y == 1))
                dev_set_error += err
                conf_matrix[0, 0] += tn
                conf_matrix[1, 0] += fn
                conf_matrix[0, 1] += fp
                conf_matrix[1, 1] += tp
        
            dev_set_error /= n_dev_batches       
            print("Dev set error: %.4f" % (dev_set_error))
            print(conf_matrix)

HBox(children=(IntProgress(value=0, max=40), HTML(value='')))

HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 1, train error: 1.5428


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 2, train error: 0.6946


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 3, train error: 0.3747


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 4, train error: 0.2257


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6824
[[14463.  4077.]
 [  133.   105.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 5, train error: 0.2828


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 6, train error: 0.2533


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 7, train error: 0.1894


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 8, train error: 0.1398


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6439
[[17999.   541.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 9, train error: 0.0863


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 10, train error: 0.2138


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 11, train error: 0.0916


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 12, train error: 0.0859


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 1.2881
[[1.5241e+04 3.2990e+03]
 [2.2400e+02 1.4000e+01]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 13, train error: 0.2329


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 14, train error: 0.0653


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 15, train error: 0.4444


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 16, train error: 0.1933


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6760
[[1.722e+04 1.320e+03]
 [2.340e+02 4.000e+00]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 17, train error: 0.1165


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 18, train error: 0.1272


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 19, train error: 0.2161


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 20, train error: 0.1280


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6643
[[18156.   384.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 21, train error: 0.1173


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 22, train error: 0.0460


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 23, train error: 0.0455


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 24, train error: 0.0489


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6831
[[18223.   317.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 25, train error: 0.0751


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 26, train error: 0.0706


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 27, train error: 0.0456


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 28, train error: 0.0606


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.6744
[[18303.   237.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 29, train error: 0.0317


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 30, train error: 0.0358


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 31, train error: 0.0418


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 32, train error: 0.0616


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7124
[[18505.    35.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 33, train error: 0.0473


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 34, train error: 0.0603


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 35, train error: 0.1076


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 36, train error: 0.0830


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7718
[[18445.    95.]
 [  238.     0.]]


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 37, train error: 0.0584


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 38, train error: 0.0853


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 39, train error: 0.0343


HBox(children=(IntProgress(value=0, max=300), HTML(value='')))

Epoch 40, train error: 0.0824


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

Dev set error: 0.7971
[[18417.   123.]
 [  238.     0.]]
