In [1]:
import pathlib
import tensorflow as tf
import xarray as xr

# Python generator to `tf.data.Dataset`

Once the feedloop is setup as a python generator, it can be converted in to a tensorflow `Dataset` object for optimal ingestion into Keras. We exemplify this here, using a generator that reads tiles in TIFF format from disk:

In [2]:
# path with TIFF files
tile_path = "./retiled_test"

In [3]:
# generator - much simplified version of the feed loop, reading cutouts from disk
def get_tile(tile_path):
    # string input is passed by tf as byte object, need to decode
    tile_path = tile_path.decode() if isinstance(tile_path, bytes) else tile_path
    tile_path = pathlib.Path(tile_path)
    for file in tile_path.glob("*.tif"):
        yield xr.open_rasterio(file).data

In [4]:
# testing the iterator
iterator = get_tile(tile_path)
element = next(iterator)
element.dtype, element.shape




(dtype('uint16'), (3, 20, 20))

In [5]:
# set up Dataset object
ds = tf.data.Dataset.from_generator(
    get_tile, 
    args=[tile_path], 
    output_types=tf.int16, 
    output_shapes=(3, 20, 20)
)
ds

<FlatMapDataset shapes: (3, 20, 20), types: tf.int16>

In [6]:
# shuffle elements with a buffer of 100, making batches with 20 cutouts each
# (in first batch we can have elements from the first 120 files)
batches = ds.shuffle(buffer_size=100).batch(20, drop_remainder=True)
batches

<BatchDataset shapes: (20, 3, 20, 20), types: tf.int16>