In [None]:
import hub
import tensorflow as tf
from time import sleep
import numpy as np
from matplotlib import pyplot as plt

In [None]:
# helper function to visualize images
def visualize(image):
    """Draw one image as a 5x5-inch grayscale figure with the axes hidden.

    Args:
        image: Array holding exactly 512 * 512 values; it is reshaped to
            (512, 512) before plotting. Values are mapped to gray levels
            over the fixed range [0, 1].
    """
    pixels = image.reshape(512, 512)
    _, ax = plt.subplots(figsize=(5, 5))
    ax.set_axis_off()
    ax.imshow(pixels, cmap='gray', vmin=0, vmax=1)
    


In [None]:
# Open the dataset, then report its shape (the number of samples) followed by
# its schema (the structure of the dataset), one per line.
ds = hub.Dataset("some url")
print(ds.shape, ds.schema, sep="\n")

In [None]:
# Fetch the image sequence of one sample (any other sample can be accessed the same way).
image_sequence = ds["image", 100040].compute()
# Print the shape explicitly: a bare expression in the middle of a cell is
# evaluated but never displayed, only the cell's final expression is.
print(image_sequence.shape)
visualize(image_sequence[0])  # visualize first image in sequence


In [None]:
# Peek at the first sample only: the loop exits after a single iteration.
for item in ds:
    chexpert_labels = item["label_chexpert"]  # any other key from the schema works the same way
    view_position = item["viewPosition"]
    print(chexpert_labels.compute())
    print(view_position.compute())  # the ClassLabels are stored as integers
    print(view_position.compute(label_name=True))  # label_name=True retrieves the string labels
    break


In [None]:
# Take a subset of the dataset (the slice 500:1000) and report its length.
SUBSET_START, SUBSET_STOP = 500, 1000
subset = ds[SUBSET_START:SUBSET_STOP]
print(len(subset))


In [None]:
def only_frontal(sample):
    """Return True when the sample's view positions include "PA" or "AP".

    Args:
        sample: Mapping-like dataset sample whose ``"viewPosition"`` entry
            exposes ``compute(label_name=True)``. The computed value only
            needs to support ``in`` (presumably the string labels of the
            sample's images; confirm against the dataset schema).

    Returns:
        bool: True if "PA" or "AP" is contained in the computed labels.
    """
    # Keyword form matches the other compute(label_name=True) call in this notebook.
    view_labels = sample["viewPosition"].compute(label_name=True)
    # `in` already yields a bool, so no `True if ... else False` wrapper is needed.
    return "PA" in view_labels or "AP" in view_labels

# Keep only the samples accepted by only_frontal, then report how many remain.
filtered = subset.filter(only_frontal)
num_frontal = len(filtered)
print(num_frontal)


In [None]:
# Send only the keys relevant for training to TensorFlow.
# Calling filtered.to_tensorflow() with no key_list also works, but it is slower
# because other, irrelevant data is fetched too. The result of such a call was
# previously assigned and immediately overwritten, so only this conversion is kept.
tds = filtered.to_tensorflow(key_list=["image", "label_chexpert", "viewPosition"])


In [None]:
def get_image(viewPosition, images, frontal_ids=(5, 12)):
    """Pick the first matching image of a sample and triple it along axis 2.

    Args:
        viewPosition: Iterable of integer view-position class ids,
            index-aligned with ``images``.
        images: Indexable collection of images. The selected image must have
            at least three dimensions, because copies are joined along axis 2.
        frontal_ids: Class ids that count as a match. Defaults to (5, 12),
            presumably the ids of the frontal views; TODO confirm against
            the dataset schema.

    Returns:
        The first image whose view position is in ``frontal_ids``,
        concatenated with itself three times along axis 2.

    Raises:
        ValueError: If no view position is in ``frontal_ids``.
    """
    for idx, view_id in enumerate(viewPosition):
        if view_id in frontal_ids:
            frame = images[idx]
            return np.concatenate((frame, frame, frame), axis=2)
    # Falling off the end would return None, which cannot be converted to the
    # tensor dtype that tf.py_function is told to expect; fail with a clear message.
    raise ValueError(
        f"no view position matching {tuple(frontal_ids)} found in sample"
    )

def to_model_fit(sample):
    """Map a dataset sample to an ``(image, labels)`` pair.

    The image is produced by running ``get_image`` through ``tf.py_function``
    on the sample's ``"viewPosition"`` and ``"image"`` entries, with a declared
    output dtype of ``tf.uint16``. The labels are the sample's
    ``"label_chexpert"`` entry, passed through unchanged.

    Args:
        sample: Mapping with ``"viewPosition"``, ``"image"`` and
            ``"label_chexpert"`` keys.

    Returns:
        Tuple ``(image, labels)``.
    """
    target = sample["label_chexpert"]
    frontal_image = tf.py_function(
        func=get_image,
        inp=[sample["viewPosition"], sample["image"]],
        Tout=tf.uint16,
    )
    return frontal_image, target

# Convert each sample into (X, y) format for training, then batch and prefetch.
BATCH_SIZE = 8
tds_train = (
    tds.map(to_model_fit)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Walk the batched pipeline once; a real training step would replace the sleep.
STEP_DELAY_SECONDS = 0.1
for batch in tds_train:
    sleep(STEP_DELAY_SECONDS)  # simulate training delay