## TFRecords check

This code checks whether the recovered / decoded data structure is the same data pre- and post-conversion to TFRecords format.

In [1]:
import numpy as np
import tensorflow as tf
import pathlib
from functools import partial

### Encoding functions

In [2]:
# TF encode / decode functionality

def _float_feature_seq(eeg_channel):
    """ Convert sequence of EEG values to tf.train.FeatureList """
    feature_list = tf.train.FeatureList(feature=[
        tf.train.Feature(float_list=tf.train.FloatList(
            value=eeg_channel))])
    
    return feature_list

def _int_feature_scalar(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def encode_single_example(eeg_data: np.array, label: int):
    """ TFRecords: process single EEG trial as a SequenceExample"""
    #print(eeg_data.shape)
    eeg_data_flat = eeg_data.copy().reshape((-1, 1))
    
    fl_dict = {'eeg_data':_float_feature_seq(eeg_data_flat)}
    label_dict = {'label':  _int_feature_scalar(label)}
    
    label_features = tf.train.Features(feature=label_dict)
    
    #print(f'label_features : {label_features}')
    
    feature_lists = tf.train.FeatureLists(feature_list=fl_dict)
    protobuff = tf.train.SequenceExample(context=label_features,
                                         feature_lists=feature_lists)

    protobuff_serialised = protobuff.SerializeToString()
    
    return protobuff_serialised

def convert(data, labels, n_per_tfr, fname=None):
    """ Convert NumPY EEG data to TFRecords
    
    Args:
        data:      NumPy array of EEG data (shape = (n_batch, n_chan, n_time))
        labels:    Accompanying labels for each row in `data`
        n_per_tf:  How many EEG samples to include per TFRecord
        fname:     Filename for saved TFRecords
        
    Returns:
        None
    """
    
    max_range = (len(data) // n_per_tfr)
    idx = [n_per_tfr * i for i in range(max_range+1)]
    loop_idx = list(zip(idx, idx[1:]))

    for i, (start, stop) in enumerate(loop_idx, start=1):
        print(f'Iteration {i}/{len(loop_idx)}, start={start}, stop={stop}')

        X = data[start:stop]
        y = labels[start:stop]

        file_path = f'{fname}_file{i}.tfrecords'

        with tf.io.TFRecordWriter(file_path) as writer:
            for sample, label in zip(X, y):
                serialised_example = encode_single_example(sample, label)                              
                writer.write(serialised_example)

            writer.close()

### Decoding function (NO Transpose)

In [3]:
def decode_single_example(serialised_example, START_WINDOW, END_WINDOW,
                           TOTAL_TIMEPOINTS, N_ELECTRODES):

    data_dim = TOTAL_TIMEPOINTS*N_ELECTRODES
    #print(f'expected data_dim = {TOTAL_TIMEPOINTS}*64={data_dim}')

    context_desc = {'label': tf.io.FixedLenFeature([], dtype=tf.int64)}
    feature_desc = {
        'eeg_data': tf.io.FixedLenSequenceFeature([data_dim], dtype=tf.float32)
    }

    context, data = tf.io.parse_single_sequence_example(
        serialized=serialised_example,
        context_features=context_desc,
        sequence_features=feature_desc,
        name='parsing_single_seq_example')

    data = data['eeg_data']
    data = tf.reshape(data, (N_ELECTRODES, TOTAL_TIMEPOINTS))
    #data = tf.transpose(data, (1,0))

    data = data[:, START_WINDOW:END_WINDOW] # Extract (potential sub-window)
    WINDOW_LENGTH = END_WINDOW - START_WINDOW
    
    # If using SVM and need to flatten again
    #data = tf.reshape(data, (1, WINDOW_LENGTH * N_ELECTRODES))

    label = context['label']
    label = tf.cast(label, tf.int32)

    return data, label

### Read TFRecords and return dataset

In [4]:
def get_dataset(files, BATCH_SIZE, repeat,
               START_WINDOW, END_WINDOW, TOTAL_TIMEPOINTS, n_electrodes):
    
    decode_single_example_fn = partial(decode_single_example,
                                       START_WINDOW=START_WINDOW,
                                       END_WINDOW=END_WINDOW,
                                       TOTAL_TIMEPOINTS=TOTAL_TIMEPOINTS,
                                       N_ELECTRODES=n_electrodes)
    
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=1)
    dataset = dataset.map(decode_single_example_fn, num_parallel_calls=1)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.cache()
    #dataset = dataset.shuffle(BATCH_SIZE)
    dataset = dataset.repeat(repeat)

    return dataset

## Build fake data & convert it

In [5]:
n_electrodes = 8
n_timesteps = 5
n_batch = 200

data = np.zeros((n_batch, 
                n_electrodes,
                n_timesteps))

single_sample = np.zeros((n_electrodes,
                         n_timesteps))

for i in range(n_electrodes):
    single_sample[i, :] = np.arange(n_timesteps)
    
for i in range(n_batch):
    data[i,:] = single_sample

labels = np.arange(len(data))

Each single sample of data is 5 timesteps over 8 channels (like if it were 8 EEG electrodes). Each channel just has the same value of increasing integers. This is the format that the real data is in as it is encoded (see the input docstring for the `convert` function).

In [6]:
data[5]

array([[0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.]])

In [7]:
# convert data and put 50 samples in each tfrecord
convert(data, labels, 50, fname="monday_testing")

Iteration 1/4, start=0, stop=50
Iteration 2/4, start=50, stop=100
Iteration 3/4, start=100, stop=150
Iteration 4/4, start=150, stop=200


Now, recover all tfrecords and extract the dataset based on those values

In [8]:
directory = pathlib.Path('.')
files = list(directory.glob('monday*.tfrecords'))
files = [str(x) for x in files]
print(files)
BATCH_SIZE = 1
ds = get_dataset(files, BATCH_SIZE, 1, 0, n_timesteps, n_timesteps, n_electrodes)

['monday_testing_file1.tfrecords', 'monday_testing_file2.tfrecords', 'monday_testing_file3.tfrecords', 'monday_testing_file4.tfrecords', 'monday_testing_file5.tfrecords']


In [9]:
for x, y in ds:
    if y.numpy() == 5:
        print(x)
        global test_sample
        test_sample = x.numpy()

tf.Tensor(
[[[0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]
  [0. 1. 2. 3. 4.]]], shape=(1, 8, 5), dtype=float32)


### Here we see, if we don't transpose, we recover the exact same information that went in. This is a good sign.

What we want is to have the time dimension first, then the channel dimension, because this is how temporal data is fed to models in Tensorflow / Flax. For sklearn and EEG-library (which works with sklearn) it's with the temporal dimension last.

So what we want the decoded output to look like is the transposed input, which is:

In [11]:
data[5].T

array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [2., 2., 2., 2., 2., 2., 2., 2.],
       [3., 3., 3., 3., 3., 3., 3., 3.],
       [4., 4., 4., 4., 4., 4., 4., 4.]])

So if we define a new decoding function where we uncomment out the transpose line, if we can recover the above, then we've got the correct format.

In [12]:
def decode_single_example(serialised_example, START_WINDOW, END_WINDOW,
                           TOTAL_TIMEPOINTS, N_ELECTRODES):

    data_dim = TOTAL_TIMEPOINTS*N_ELECTRODES
    #print(f'expected data_dim = {TOTAL_TIMEPOINTS}*64={data_dim}')

    context_desc = {'label': tf.io.FixedLenFeature([], dtype=tf.int64)}
    feature_desc = {
        'eeg_data': tf.io.FixedLenSequenceFeature([data_dim], dtype=tf.float32)
    }

    context, data = tf.io.parse_single_sequence_example(
        serialized=serialised_example,
        context_features=context_desc,
        sequence_features=feature_desc,
        name='parsing_single_seq_example')

    data = data['eeg_data']
    data = tf.reshape(data, (N_ELECTRODES, TOTAL_TIMEPOINTS))
    data = tf.transpose(data, (1,0))

    data = data[START_WINDOW:END_WINDOW, :] # Extract (potential sub-window)
    WINDOW_LENGTH = END_WINDOW - START_WINDOW
    
    # If using SVM and need to flatten again
    #data = tf.reshape(data, (1, WINDOW_LENGTH * N_ELECTRODES))

    label = context['label']
    label = tf.cast(label, tf.int32)

    return data, label

In [13]:
directory = pathlib.Path('.')
files = list(directory.glob('monday*.tfrecords'))
files = [str(x) for x in files]
print(files)
BATCH_SIZE = 1
ds = get_dataset(files, BATCH_SIZE, 1, 0, n_timesteps, n_timesteps, n_electrodes)

['monday_testing_file1.tfrecords', 'monday_testing_file2.tfrecords', 'monday_testing_file3.tfrecords', 'monday_testing_file4.tfrecords', 'monday_testing_file5.tfrecords']


In [14]:
for x, y in ds:
    if y.numpy() == 5:
        print(x)
        global test_sample
        test_sample = x.numpy()

tf.Tensor(
[[[0. 0. 0. 0. 0. 0. 0. 0.]
  [1. 1. 1. 1. 1. 1. 1. 1.]
  [2. 2. 2. 2. 2. 2. 2. 2.]
  [3. 3. 3. 3. 3. 3. 3. 3.]
  [4. 4. 4. 4. 4. 4. 4. 4.]]], shape=(1, 5, 8), dtype=float32)


The decoded sample we were taking as a checking example (5th sample in the fake dataset) matches exactly.