# Compare dataset types
This notebook contains code that compares the run time of querying dataset types. 

In [1]:
# imports
import numpy as np
import tensorflow as tf
import time
from komorebi.libs.utilities.io_utils import load_pickle_object
import random

## Sharded Implementation

In [2]:
sharded_dataset_path = "/Users/andy/Projects/biology/research/komorebi/data/attention_validation_dataset/sharded_attention_dataset.pkl"
sharded_dataset = load_pickle_object(sharded_dataset_path)

sharded_times = []
indices = [random.randint(0, 7999) for _ in xrange(100)]

for i in range(2):
    sharded_ts = time.time()
    sharded_dataset.get_training_examples(indices)
    sharded_te = time.time()
    sharded_times.append(sharded_te-sharded_ts) 

## Tensorflow implementation

In [4]:
SEQUENCE_SHAPE = (1000, 4)
ANNOTATION_SHAPE = (75, 320)

def parse_example(tf_example):
    """Parse tensorflow example"""
    
    features_map = {
        'sequence_raw': tf.FixedLenFeature([], tf.string),
        'label_raw': tf.FixedLenFeature([], tf.string),
        'annotation_raw': tf.FixedLenFeature([], tf.string)}
    
    parsed_example = tf.parse_single_example(tf_example, features_map)
    
    sequence_raw = tf.decode_raw(parsed_example['sequence_raw'], tf.uint8)
    annotation_raw = tf.decode_raw(parsed_example['annotation_raw'], tf.float32)
    
    sequence = tf.reshape(sequence_raw, SEQUENCE_SHAPE)
    label = tf.decode_raw(parsed_example['label_raw'], tf.uint8)
    annotation = tf.reshape(annotation_raw, ANNOTATION_SHAPE)
    
    return {'sequence': sequence, 'label': label, 'annotation': annotation}

TF_VALIDATION_DATASET = "/tmp/validation_dataset.tfrecord"
tf_dataset = tf.data.TFRecordDataset([TF_VALIDATION_DATASET])
tf_dataset = tf_dataset.prefetch(8000) #buffer size
tf_dataset = tf_dataset.map(parse_example, num_parallel_calls=6)
iterator = tf_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

batched_dataset = tf_dataset.batch(100)
batched_iter = batched_dataset.make_one_shot_iterator()

In [5]:
sess = tf.InteractiveSession()
tf_times = []

for _ in range(7):
    tf_ts = time.time()
    batched_next = batched_iter.get_next()
    b_sequence = batched_next['sequence']
    b_annotation = batched_next['annotation']
    b_label = batched_next['label']
    print b_sequence.eval().shape, b_annotation.eval().shape, b_label.eval().shape
    tf_te = time.time()
    tf_times.append(tf_te - tf_ts)

sess.close()

(100, 1000, 4) (100, 75, 320) (100, 919)
(100, 1000, 4) (100, 75, 320) (100, 919)
(100, 1000, 4) (100, 75, 320) (100, 919)
(100, 1000, 4) (100, 75, 320) (100, 919)
(100, 1000, 4) (100, 75, 320) (100, 919)
(100, 1000, 4) (100, 75, 320) (100, 919)
(100, 1000, 4) (100, 75, 320) (100, 919)


In [6]:
#print sharded_times
print tf_times

[4.647903919219971, 0.03866004943847656, 0.039613962173461914, 0.038726091384887695, 0.04016900062561035, 0.03953218460083008, 0.03882098197937012]
