# Optimize data load and preprocessing with tf.data

**Learning Objectives**
1. Learn how to use tf.data to read data from memory
1. Learn how to use tf.data to read data from disk
1. Learn how to write production input pipelines with feature engineering (batching, shuffling, etc.)
1. Learn how to optimize a pipeline with tf.data


In this notebook, we will start by refactoring the linear regression we implemented in the previous lab so that it takes its data from a `tf.data.Dataset`, and we will learn how to implement **stochastic gradient descent** with it. In this case, the original dataset will be synthetic and read by the `tf.data` API directly from memory.

We will use TensorFlow as the framework, but **tf.data also works with other frameworks such as JAX or PyTorch**.

In the second part, we will learn how to load a dataset with the `tf.data` API when the dataset resides on disk, and then how to optimize the data pipeline.

In [1]:
import os
import warnings

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
warnings.filterwarnings("ignore")

import json
import math
from pprint import pprint

import numpy as np
import tensorflow as tf

print(tf.version.VERSION)



2.18.1


## Loading data from memory

### Creating the dataset

Let's consider the synthetic dataset of the previous section:

In [2]:
N_POINTS = 10
X = tf.constant(range(N_POINTS), dtype=tf.float32)
Y = 2 * X + 10



We begin by implementing a function that takes as input:

- our $X$ and $Y$ vectors of synthetic data generated by the linear function $y = 2x + 10$
- the number of passes over the dataset we want to train on (`epochs`)
- the size of the batches in the dataset (`batch_size`)

It returns a `tf.data.Dataset`.

**Remark:** The last batch may contain fewer elements than `batch_size`, because the dataset runs out of elements.

If you want every batch to contain exactly the same number of elements, you have to discard the last batch by
setting:

```python
dataset = dataset.batch(batch_size, drop_remainder=True)
```

We will do that here.
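A quick way to see the effect of `drop_remainder` is to batch a small toy dataset both ways. This sketch uses `tf.data.Dataset.range(10)` as a stand-in for the synthetic data:

```python
import tensorflow as tf

# Ten toy elements: 0, 1, ..., 9 (a stand-in for the synthetic X values).
toy = tf.data.Dataset.range(10)

# Default batching keeps the short final batch: 3 + 3 + 3 + 1 elements.
kept = [batch.numpy().tolist() for batch in toy.batch(3)]

# drop_remainder=True discards it, so every batch has exactly 3 elements.
dropped = [batch.numpy().tolist() for batch in toy.batch(3, drop_remainder=True)]

print(kept)     # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(dropped)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

A side effect of `drop_remainder=True` is that the batch dimension becomes static (`3` instead of `None` in `element_spec`), which can matter for code that requires fixed shapes.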

In [3]:
def create_dataset(X, Y, epochs, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((X, Y))
    dataset = dataset.repeat(epochs).batch(batch_size, drop_remainder=True)
    return dataset

Let's test our function by iterating twice over our dataset in batches of 3 data points:

In [4]:
BATCH_SIZE = 3
EPOCH = 2

dataset = create_dataset(X, Y, epochs=EPOCH, batch_size=BATCH_SIZE)

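num_batches = 0
for x, y in dataset:
    print("x:", x.numpy(), "y:", y.numpy())
    assert len(x) == BATCH_SIZE
    assert len(y) == BATCH_SIZE
    num_batches += 1

# 2 epochs x 10 points = 20 elements -> 6 full batches of 3 (the last 2 are dropped)
assert num_batches == (EPOCH * N_POINTS) // BATCH_SIZE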

x: [0. 1. 2.] y: [10. 12. 14.]
x: [3. 4. 5.] y: [16. 18. 20.]
x: [6. 7. 8.] y: [22. 24. 26.]
x: [9. 0. 1.] y: [28. 10. 12.]
x: [2. 3. 4.] y: [14. 16. 18.]
x: [5. 6. 7.] y: [20. 22. 24.]
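Notice that the fourth batch, `[9. 0. 1.]`, straddles the two epochs: because `create_dataset` calls `repeat()` before `batch()`, the epochs are joined into one stream before it is cut into batches. Swapping the order instead produces a short batch at the end of every epoch. A small sketch on a 4-element toy dataset:

```python
import tensorflow as tf

toy = tf.data.Dataset.range(4)

# repeat() then batch(): epochs are joined first, so batches can span epochs.
repeat_first = [b.numpy().tolist() for b in toy.repeat(2).batch(3)]

# batch() then repeat(): each epoch is batched separately, leaving a short
# final batch in every epoch.
batch_first = [b.numpy().tolist() for b in toy.batch(3).repeat(2)]

print(repeat_first)  # [[0, 1, 2], [3, 0, 1], [2, 3]]
print(batch_first)   # [[0, 1, 2], [3], [0, 1, 2], [3]]
```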


### Loss function and gradients

The loss function and the function that computes the gradients are the same as before:

In [5]:
def loss_mse(X, Y, w0, w1):
    Y_hat = w0 * X + w1
    errors = (Y_hat - Y) ** 2
    return tf.reduce_mean(errors)


def compute_gradients(X, Y, w0, w1):
    with tf.GradientTape() as tape:
        loss = loss_mse(X, Y, w0, w1)
    return tape.gradient(loss, [w0, w1])
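For the model $\hat{y} = w_0 x + w_1$ with loss $L = \frac{1}{n}\sum_i (\hat{y}_i - y_i)^2$, the gradients that `tape.gradient` computes have a simple closed form: $\frac{\partial L}{\partial w_0} = \frac{2}{n}\sum_i (\hat{y}_i - y_i)\,x_i$ and $\frac{\partial L}{\partial w_1} = \frac{2}{n}\sum_i (\hat{y}_i - y_i)$. The plain-Python sketch below (no TensorFlow; the helper and variable names are illustrative) applies them by hand to the first batch of the training loop, `x = [0, 1]` and `y = [10, 12]`, starting from $w_0 = w_1 = 0$ with a learning rate of 0.02. The resulting weights and loss match the `STEP 0` line printed by the training loop.

```python
def mse_and_gradients(xs, ys, w0, w1):
    """Mean squared error of y_hat = w0 * x + w1 and its gradients."""
    n = len(xs)
    residuals = [w0 * x + w1 - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / n
    dw0 = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
    dw1 = 2.0 / n * sum(residuals)
    return loss, dw0, dw1


lr = 0.02
x_batch, y_batch = [0.0, 1.0], [10.0, 12.0]  # first batch of the training loop
w0_demo, w1_demo = 0.0, 0.0

_, grad_w0, grad_w1 = mse_and_gradients(x_batch, y_batch, w0_demo, w1_demo)
print(grad_w0, grad_w1)  # -12.0 -22.0

# One SGD step, mirroring w.assign_sub(dw * LEARNING_RATE)
w0_demo -= lr * grad_w0
w1_demo -= lr * grad_w1
loss_demo, _, _ = mse_and_gradients(x_batch, y_batch, w0_demo, w1_demo)
print(round(w0_demo, 4), round(w1_demo, 4), round(loss_demo, 3))  # 0.24 0.44 109.768
```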

### Training loop

The main difference is that, in the training loop, we now iterate directly over the `tf.data.Dataset` generated by our `create_dataset` function.

We will configure the dataset so that it iterates 250 times over our synthetic dataset in batches of 2.

In [6]:
EPOCHS = 250
BATCH_SIZE = 2
LEARNING_RATE = 0.02

MSG = "STEP {step} - loss: {loss}, w0: {w0}, w1: {w1}\n"

w0 = tf.Variable(0.0)
w1 = tf.Variable(0.0)

dataset = create_dataset(X, Y, epochs=EPOCHS, batch_size=BATCH_SIZE)

for step, (X_batch, Y_batch) in enumerate(dataset):
    dw0, dw1 = compute_gradients(X_batch, Y_batch, w0, w1)
    w0.assign_sub(dw0 * LEARNING_RATE)
    w1.assign_sub(dw1 * LEARNING_RATE)

    if step % 100 == 0:
        loss = loss_mse(X_batch, Y_batch, w0, w1)
        print(MSG.format(step=step, loss=loss, w0=w0.numpy(), w1=w1.numpy()))

assert loss < 0.0001
assert abs(w0 - 2) < 0.001
assert abs(w1 - 10) < 0.001

STEP 0 - loss: 109.76800537109375, w0: 0.23999999463558197, w1: 0.4399999976158142

STEP 100 - loss: 9.363959312438965, w0: 2.55655837059021, w1: 6.674341678619385

STEP 200 - loss: 1.393267273902893, w0: 2.2146825790405273, w1: 8.717182159423828

STEP 300 - loss: 0.20730558037757874, w0: 2.082810878753662, w1: 9.505172729492188

STEP 400 - loss: 0.03084510937333107, w0: 2.03194260597229, w1: 9.809128761291504

STEP 500 - loss: 0.004589457996189594, w0: 2.012321710586548, w1: 9.926374435424805

STEP 600 - loss: 0.0006827632314525545, w0: 2.0047526359558105, w1: 9.971602439880371

STEP 700 - loss: 0.00010164897685172036, w0: 2.0018346309661865, w1: 9.989042282104492

STEP 800 - loss: 1.5142451957217418e-05, w0: 2.000706911087036, w1: 9.995771408081055

STEP 900 - loss: 2.256260358990403e-06, w0: 2.0002737045288086, w1: 9.998367309570312

STEP 1000 - loss: 3.3405058275093324e-07, w0: 2.000105381011963, w1: 9.999371528625488

STEP 1100 - loss: 4.977664502803236e-08, w0: 2.000040054321289,

## Loading data from disk

### Locating the CSV files

We will start with the **taxifare dataset** CSV files that we wrote out in a previous lab. 

The taxifare dataset files have been saved into `../data`.

Check that this is the case in the cell below. If not, regenerate the taxifare
dataset by running the previous lab notebook:

In [7]:
!ls -l ../data/taxi*.csv

-rw-r--r-- 1 jupyter jupyter 123675 Aug 29 17:55 ../data/taxi-test.csv
-rw-r--r-- 1 jupyter jupyter 579140 Aug 29 17:55 ../data/taxi-train.csv
-rw-r--r-- 1 jupyter jupyter 399647 Aug 29 17:55 ../data/taxi-valid.csv


### Use Low-level tf.data API to read the CSV files

To get a more flexible pipeline, we can use the low-level tf.data APIs, which give us full control over the behavior of the pipeline.

For text-based data, including CSV, we can use `TextLineDataset`, which yields one string element per line of the input file.

In [8]:
ds = tf.data.TextLineDataset("../data/taxi-train.csv")
ds

<TextLineDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

Note that the Dataset object (`ds`) is only a definition of the pipeline: it hasn't loaded any data yet.<br>
Let's iterate over the first two elements of this dataset using `ds.take(2)`:

In [9]:
for data in ds.take(2):
    print(data)

tf.Tensor(b'fare_amount,pickup_datetime,pickuplon,pickuplat,dropofflon,dropofflat,passengers,key', shape=(), dtype=string)
tf.Tensor(b'11.3,2011-01-28 20:42:59 UTC,-73.999022,40.739146,-73.990369,40.717866,1,0', shape=(), dtype=string)


It looks like the header row is loaded as the first element. Since it's not part of the training data, let's skip it with the `skip()` method.

In [10]:
ds = tf.data.TextLineDataset("../data/taxi-train.csv").skip(1)

for data in ds.take(2):
    print(data)

tf.Tensor(b'11.3,2011-01-28 20:42:59 UTC,-73.999022,40.739146,-73.990369,40.717866,1,0', shape=(), dtype=string)
tf.Tensor(b'7.7,2011-06-27 04:28:06 UTC,-73.987443,40.729221,-73.979013,40.758641,1,1', shape=(), dtype=string)
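`skip(1)` drops only the first line of the whole dataset. That is enough here because we read a single file, but `TextLineDataset` also accepts a list of files and reads them back to back, in which case the header of every file after the first would slip through. The self-contained sketch below uses two hypothetical two-line files in a temporary directory to show the problem, and one way to skip each file's header with `flat_map`:

```python
import os
import tempfile

import tensorflow as tf

# Two hypothetical CSV files, each with a header row and one data row.
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(2):
    path = os.path.join(tmp_dir, f"part{i}.csv")
    with open(path, "w") as f:
        f.write(f"x,y\n{i},{i * 10}\n")
    paths.append(path)

# Naive: the files are read back to back, so skip(1) removes only the first header.
naive = [line.numpy() for line in tf.data.TextLineDataset(paths).skip(1)]
print(naive)  # [b'0,0', b'x,y', b'1,10']

# Per file: open each file as its own dataset and skip its header there.
per_file = tf.data.Dataset.from_tensor_slices(paths).flat_map(
    lambda p: tf.data.TextLineDataset(p).skip(1)
)
print([line.numpy() for line in per_file])  # [b'0,0', b'1,10']
```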


### Transforming the features with `.map()`

At this point, we've loaded the CSV file as a text file, and each row is simply represented as a single string value containing separators (`,`).

Let's write a parsing function that takes a row and splits it into multiple values.

In [11]:
def parse_csv(row):
    return tf.strings.split(row, ",")

Let's make sure it works by calling this function in a for loop over `.take()`.

In [12]:
ds = tf.data.TextLineDataset("../data/taxi-train.csv").skip(1)

for data in ds.take(2):
    values = parse_csv(data)
    pprint(values)

<tf.Tensor: shape=(8,), dtype=string, numpy=
array([b'11.3', b'2011-01-28 20:42:59 UTC', b'-73.999022', b'40.739146',
       b'-73.990369', b'40.717866', b'1', b'0'], dtype=object)>
<tf.Tensor: shape=(8,), dtype=string, numpy=
array([b'7.7', b'2011-06-27 04:28:06 UTC', b'-73.987443', b'40.729221',
       b'-73.979013', b'40.758641', b'1', b'1'], dtype=object)>


Instead of calling the function in a for loop, we can pass it to `.map()` to make it part of the pipeline.

In [13]:
ds = tf.data.TextLineDataset("../data/taxi-train.csv").skip(1).map(parse_csv)

for data in ds.take(2):
    print(data)

tf.Tensor(
[b'11.3' b'2011-01-28 20:42:59 UTC' b'-73.999022' b'40.739146'
 b'-73.990369' b'40.717866' b'1' b'0'], shape=(8,), dtype=string)
tf.Tensor(
[b'7.7' b'2011-06-27 04:28:06 UTC' b'-73.987443' b'40.729221'
 b'-73.979013' b'40.758641' b'1' b'1'], shape=(8,), dtype=string)


Now let's extend the `parse_csv` function.<br>
For training, we want each element of the dataset to be a `(features, label)` tuple.

In this CSV file we have these columns:

In [14]:
!head -1 ../data/taxi-train.csv

fare_amount,pickup_datetime,pickuplon,pickuplat,dropofflon,dropofflat,passengers,key


Let's say we want to predict the `fare_amount` value using `pickuplon`, `pickuplat`, `dropofflon` and `dropofflat` as features. Then `parse_csv` becomes:

In [15]:
def parse_csv(row):
    ds = tf.strings.split(row, ",")
    # Label: fare_amount
    label = tf.strings.to_number(ds[0])
    # Features: pickuplon, pickuplat, dropofflon, dropofflat (columns 2 to 5)
    features = tf.strings.to_number(ds[2:6])
    return features, label

In [16]:
ds = tf.data.TextLineDataset("../data/taxi-train.csv").skip(1)
ds = ds.map(parse_csv)

for features, label in ds.take(2):
    print(f"features: \n  {features}, \nlabel: \n  {label} \n++++")

features: 
  [-73.99902   40.739147 -73.99037   40.717865], 
label: 
  11.300000190734863 
++++
features: 
  [-73.98744  40.72922 -73.97901  40.75864], 
label: 
  7.699999809265137 
++++
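`tf.strings.split` treats every comma as a separator, which works for this file but would break on a quoted field such as `"New York, NY"`. If your data may contain quoted fields, `tf.io.decode_csv` is an alternative that understands quoting and converts each column to the type of its entry in `record_defaults`. A minimal sketch on the first data row of `taxi-train.csv`; the column types in `record_defaults` are our assumption about the schema:

```python
import tensorflow as tf

row = tf.constant(
    "11.3,2011-01-28 20:42:59 UTC,-73.999022,40.739146,-73.990369,40.717866,1,0"
)

# One default per column; its dtype sets the parsed type (assumed schema:
# float label, string datetime, 4 float coordinates, int passengers, int key).
record_defaults = [[0.0], [""], [0.0], [0.0], [0.0], [0.0], [0], [0]]

columns = tf.io.decode_csv(row, record_defaults=record_defaults)
decoded_label = columns[0]
decoded_features = tf.stack(columns[2:6])  # pickuplon, pickuplat, dropofflon, dropofflat
print(decoded_features.numpy(), decoded_label.numpy())

# Unlike tf.strings.split, decode_csv keeps a quoted comma inside its field.
quoted = tf.io.decode_csv('"New York, NY",3', record_defaults=[[""], [0]])
print(quoted[0].numpy(), int(quoted[1].numpy()))  # b'New York, NY' 3
```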


### Batching

Training typically consumes data in batches. Let's refactor our pipeline to batch the data by adding `.batch(BATCH_SIZE)`.

In [17]:
BATCH_SIZE = 4

ds = tf.data.TextLineDataset("../data/taxi-train.csv").skip(1)
ds = ds.map(parse_csv).batch(BATCH_SIZE)

for features, label in ds.take(2):
    print(f"features: \n  {features}, \nlabel: \n  {label} \n++++")

features: 
  [[-73.99902   40.739147 -73.99037   40.717865]
 [-73.98744   40.72922  -73.97901   40.75864 ]
 [-73.98254   40.735725 -73.954796  40.77839 ]
 [-74.001945  40.740505 -73.91385   40.75856 ]], 
label: 
  [11.3  7.7 10.5 16.2] 
++++
features: 
  [[-73.99337   40.753384 -73.8609    40.7329  ]
 [-73.99624   40.721848 -73.98942   40.718052]
 [-73.97705   40.75846  -73.9849    40.744694]
 [-73.9694    40.757545 -73.95005   40.776077]], 
label: 
  [33.5  6.9  6.1  9.5] 
++++


Now our dataset yields *batches* instead of *rows*, which is what mini-batch training of neural networks expects.

### Shuffling & Repeating

When training a deep learning model in batches, and especially when training over multiple workers, it is helpful to shuffle the data. That way, consecutive batches do not simply follow the order of the input file, and different workers process different parts of the data at the same time, so the gradients averaged across workers better represent the whole dataset.<br>
We can add shuffling with `.shuffle()`. Note, however, that the shuffle buffer, whose size is set by `buffer_size`, is held in memory, so a buffer large enough to fully shuffle a very large dataset is usually impractical.
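The buffer determines both the memory cost and the quality of the shuffle: `shuffle()` fills a buffer with `buffer_size` elements, then repeatedly emits a randomly chosen element from it and refills the freed slot with the next input element. The plain-Python model below is a simplified sketch of that idea, not TensorFlow's implementation. It makes the consequence visible: an element can appear at most `buffer_size - 1` positions earlier than in the input, so a buffer much smaller than the dataset only shuffles locally, and `buffer_size=1` does not shuffle at all.

```python
import random


def buffered_shuffle(elements, buffer_size, seed=None):
    """Simplified model of a shuffle buffer (not TensorFlow's actual code)."""
    rng = random.Random(seed)
    buffer = []
    for element in elements:
        if len(buffer) < buffer_size:
            buffer.append(element)  # initial fill
            continue
        # Buffer is full: emit a random element, then put the new one in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = element
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


data = list(range(20))
shuffled = list(buffered_shuffle(data, buffer_size=4, seed=0))
print(shuffled)

# A buffer of size 1 leaves the order unchanged.
print(list(buffered_shuffle(data, buffer_size=1)) == data)  # True
```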

Let's wrap our data pipeline in a `create_dataset` function so that we can control its behavior and shuffle the data only when the dataset is used for training.

We will introduce an additional argument `mode` to our function so that the function body can distinguish the case when it needs to shuffle the data (`mode == "train"`) from the case when it shouldn't (`mode == "eval"`).

Also, let's add `.repeat()` so that the data can be read indefinitely during training. Note that the function below applies `.repeat()` in both modes.

In [18]:
def create_dataset(pattern, batch_size, mode="eval"):
    ds = tf.data.TextLineDataset(pattern).skip(1)
    ds = ds.map(parse_csv).repeat()
    if mode == "train":
        ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

Let's check that our function works well in both modes:

In [19]:
# Run this cell multiple times to see that the results differ.
tempds = create_dataset("../data/taxi-train.csv", 2, "train")
print(list(tempds.take(1)))

[(<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[-73.94901 ,  40.77736 , -73.991936,  40.74759 ],
       [-73.98325 ,  40.756077, -73.90924 ,  40.765625]], dtype=float32)>, <tf.Tensor: shape=(2,), dtype=float32, numpy=array([14.9,  9.7], dtype=float32)>)]


In [20]:
tempds = create_dataset("../data/taxi-valid.csv", 2, "eval")
print(list(tempds.take(1)))

[(<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[-73.96768 ,  40.79274 , -73.9689  ,  40.791676],
       [-74.00532 ,  40.72768 , -73.97028 ,  40.75662 ]], dtype=float32)>, <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 6.1, 16. ], dtype=float32)>)]
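Since the evaluation dataset is also repeated, it never ends on its own; that is why these cells read it with `.take(1)`. Any code that consumes it for evaluation has to bound the number of steps in the same way, for example with `take()` or a steps argument in the training framework. A short sketch of what `repeat()` does to a dataset's cardinality, on a toy `range` dataset:

```python
import tensorflow as tf

finite = tf.data.Dataset.range(5)
endless = finite.repeat()  # no count: repeat forever

print(finite.cardinality().numpy())  # 5
print(endless.cardinality().numpy() == tf.data.INFINITE_CARDINALITY)  # True

# take() bounds an endless dataset, e.g. to a fixed number of evaluation steps.
bounded = endless.take(7)
print([int(v) for v in bounded.as_numpy_iterator()])  # [0, 1, 2, 3, 4, 0, 1]
```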


## Better Performance with tf.data

Maximizing the performance of the data loading and preprocessing phase is critical for many machine learning use cases.

`tf.data` offers a number of ways to optimize the process, depending on the cause of performance bottlenecks.<br>
Let's take a look at some scenarios.

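For comparison, we use the `benchmark` function below, which simulates a training loop: it iterates over the dataset `num_epochs` times and sleeps for 10 ms per dataset element to stand in for a training step.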

In [21]:
import time


def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

### Case 1: Performance bottleneck in a heavy map operation

While feature transformation with `.map()` is flexible and convenient, it can become a performance bottleneck when the preprocessing function contains heavy operations.

Let's simulate that case by adding a short sleep to our parse function.

In [22]:
def heavy_parse_csv(row):
    ds = tf.strings.split(row, ",")
    label = tf.strings.to_number(ds[0])
    features = tf.strings.to_number(ds[2:6])

    # Perform a heavy preprocessing...
    tf.py_function(lambda: time.sleep(0.001), [], ())

    return features, label

In [23]:
def create_dataset(pattern, batch_size=128):
    ds = tf.data.TextLineDataset(pattern).skip(1)
    ds = ds.map(heavy_parse_csv)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

In [24]:
tempds = create_dataset("../data/taxi-train.csv")
benchmark(tempds)

Execution time: 19.778359695977997


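The flow looks like the figure below: each element is read, mapped, and then consumed by the training step, one after the other, so the map operation between reading and training is the bottleneck in this case.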

![Map bottleneck](https://www.tensorflow.org/guide/images/data_performance/sequential_map.svg)

Let's see how we can optimize this process.

#### Solution 1: Parallelize map

Because input elements are independent of one another, the preprocessing can be parallelized across multiple CPU cores. To make this possible, `.map()` accepts a `num_parallel_calls` argument that specifies the level of parallelism.

Instead of a fixed number, you can pass `tf.data.AUTOTUNE` to let tf.data tune the level of parallelism automatically at runtime.

In [25]:
def create_dataset(pattern, batch_size=128):
    ds = tf.data.TextLineDataset(pattern).skip(1)
    ds = ds.map(heavy_parse_csv, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

In [26]:
tempds = create_dataset("../data/taxi-train.csv")
benchmark(tempds)

Execution time: 5.041047303995583


Now it is much faster!

The flow now looks like this.
![parallelized](https://www.tensorflow.org/guide/images/data_performance/parallel_map.svg)
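Conceptually, `num_parallel_calls` lets several invocations of the map function run at the same time while, by default, the output still comes back in input order (`map()` also has a `deterministic` argument to relax that). The plain-Python analogy below uses a thread pool from the standard library and a hypothetical `slow_parse` function that sleeps to stand in for heavy preprocessing. It illustrates the idea only; it is not how tf.data schedules work internally.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_parse(x):
    """Hypothetical stand-in for a heavy per-element transformation."""
    time.sleep(0.02)
    return x * 2


inputs = list(range(20))

# Sequential map: one element at a time (about 20 x 0.02 s).
start = time.perf_counter()
sequential = [slow_parse(x) for x in inputs]
sequential_time = time.perf_counter() - start

# Parallel map with 4 workers: calls overlap, results keep input order.
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    parallel = list(pool.map(slow_parse, inputs))
parallel_time = time.perf_counter() - start

print(f"sequential: {sequential_time:.2f}s, parallel: {parallel_time:.2f}s")
print(parallel == sequential)  # True
```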

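#### Solution 2: Vectorize the map operation

`.map()` processes each individual element returned by a `Dataset`, and our current function is structured to work on one CSV line at a time. Processing data in batches is usually more efficient when feasible, because the per-call overhead of the map function is paid once per batch instead of once per element. In this simulation the effect is exaggerated: the simulated heavy work is a fixed sleep per call, so batching removes most of it.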

Let's vectorize our function (that is, have it operate over a batch of inputs at once) and apply the `batch` transformation before the `map` transformation.

In [27]:
def heavy_parse_csv_batch(row):
    ds = tf.strings.split(row, ",").to_tensor()
    label = tf.strings.to_number(ds[:, 0])
    features = tf.strings.to_number(ds[:, 2:6])

    # Perform a heavy preprocessing...
    tf.py_function(lambda: time.sleep(0.001), [], ())

    return features, label

In [28]:
def create_dataset(pattern, batch_size=128):
    ds = tf.data.TextLineDataset(pattern).skip(1)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.map(heavy_parse_csv_batch, num_parallel_calls=tf.data.AUTOTUNE)
    return ds

In [29]:
tempds = create_dataset("../data/taxi-train.csv")
benchmark(tempds)

Execution time: 1.23425766202854
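The vectorized version relies on `tf.strings.split` returning a `tf.RaggedTensor` when it is given a batch of lines (in general, rows may have different numbers of fields); `.to_tensor()` pads it into a dense 2-D string tensor so that columns can be sliced with `[:, i]`. A small sketch on two hard-coded lines from `taxi-train.csv`, checking that parsing the batch at once matches parsing each line separately:

```python
import tensorflow as tf

lines = tf.constant([
    "11.3,2011-01-28 20:42:59 UTC,-73.999022,40.739146,-73.990369,40.717866,1,0",
    "7.7,2011-06-27 04:28:06 UTC,-73.987443,40.729221,-73.979013,40.758641,1,1",
])

# Batch parsing: one split call for the whole batch.
ragged = tf.strings.split(lines, ",")  # RaggedTensor, shape (2, None)
fields = ragged.to_tensor()            # dense, shape (2, 8)
batch_labels = tf.strings.to_number(fields[:, 0])
batch_features = tf.strings.to_number(fields[:, 2:6])

# Element-wise parsing: one split call per line.
row_features = [tf.strings.to_number(tf.strings.split(line, ",")[2:6]) for line in lines]

print(batch_features.shape)  # (2, 4)
print(bool(tf.reduce_all(batch_features == tf.stack(row_features))))  # True
```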


### Case 2: Performance bottleneck in I/O (data loading)
Let's take a look at the next scenario.

In a real-world setting, the input data may be stored remotely (for example on Google Cloud Storage in a different location). A dataset pipeline that works well when reading data locally might become bottlenecked on I/O when reading data remotely because of the following differences between local and remote storage:

- **Time-to-first-byte**: Reading the first byte of a file from remote storage can take orders of magnitude longer than from local storage.
- **Read throughput**: While remote storage typically offers large aggregate bandwidth, reading a single file might only be able to utilize a small fraction of this bandwidth.

In addition, once the raw bytes are loaded into memory, it may also be necessary to deserialize and/or decrypt the data (e.g. protobuf), which requires additional computation. This overhead is present irrespective of whether the data is stored locally or remotely, but can be worse in the remote case if data is not prefetched effectively.

Let's create a custom dataset to simulate this scenario.

In [30]:
class IOBoundDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.3)

        for sample_idx in range(num_samples):
            # Reading each line from the file
            time.sleep(0.15)

            yield (sample_idx,)

    def __new__(cls, file_name, num_samples=5):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

In [31]:
ds = IOBoundDataset("dummy_file.csv").repeat(20)
benchmark(ds)

Execution time: 42.26503488502931


#### Solution 1: Cache
Since machine learning training often involves using the same dataset repeatedly, a good strategy is to cache the data during the first epoch and then retrieve it from the cache in subsequent epochs, rather than reloading it from a remote source each time.

You can simply insert `.cache()` to use caching.

In [32]:
ds = IOBoundDataset("dummy_file.csv").cache().repeat(20)
benchmark(ds)

Execution time: 3.1102266230154783
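The speed-up comes from reading the slow source only once: during the first complete pass, `cache()` records each element as it flows through, and later passes replay the recorded elements without touching the source. The plain-Python model below is a simplified sketch of that behavior (not TensorFlow's code), with a hypothetical `slow_source` generator that counts how often it is opened:

```python
import time

opens = 0


def slow_source(n=5):
    """Hypothetical slow reader; counts how many times it is opened."""
    global opens
    opens += 1
    time.sleep(0.05)  # simulated open latency
    for i in range(n):
        time.sleep(0.01)  # simulated per-record read latency
        yield i


class Cached:
    """Simplified model of Dataset.cache(): record on the first full pass, replay afterwards."""

    def __init__(self, make_source):
        self._make_source = make_source
        self._cache = None

    def __iter__(self):
        if self._cache is not None:
            yield from self._cache  # later passes: replay from memory
            return
        recorded = []
        for element in self._make_source():
            recorded.append(element)
            yield element
        self._cache = recorded  # only a complete pass fills the cache


cached_ds = Cached(slow_source)
for epoch in range(3):
    print(epoch, list(cached_ds))
print("source opened", opens, "time(s)")  # source opened 1 time(s)
```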


It is much faster! However, be careful about what you cache: by default, the cached data is stored in memory, so caching a terabyte-scale dataset, for example, is not realistic.

A common pattern is to cache only the lightweight part of the pipeline. For example, if you have a CSV file that contains the paths of the training videos, cache the small parsed CSV rather than the decoded videos by inserting `cache()` before the video-loading function:

```python 
dataset.map(parse_csv_fn).cache().map(load_video_fn)
```
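When the data is too large for memory but local disk is faster than the remote source, `cache()` also accepts a filename: the elements are then written to files with that prefix during the first complete pass and read back from them afterwards. A minimal sketch, using a temporary directory as a stand-in for local disk and a hypothetical `demo_cache` prefix:

```python
import glob
import os
import tempfile

import tensorflow as tf

cache_prefix = os.path.join(tempfile.mkdtemp(), "demo_cache")

disk_cached = tf.data.Dataset.range(5).map(lambda x: x * 2).cache(cache_prefix)

epoch_values = []
for epoch in range(2):
    # The first pass computes the map and writes the cache files; the second reads them.
    values = [int(v) for v in disk_cached.as_numpy_iterator()]
    epoch_values.append(values)
    print(epoch, values)

# Files written under the cache prefix.
print(sorted(os.path.basename(p) for p in glob.glob(cache_prefix + "*")))
```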

#### Solution 2: Interleave

To mitigate the impact of the various data extraction overheads, the `tf.data.Dataset.interleave` transformation can be used to parallelize the data loading step, interleaving the contents of other datasets (such as data file readers).

Let's say we have multiple sharded files.

In [33]:
files = [
    "dummy_file_shard001.csv",
    "dummy_file_shard002.csv",
    "dummy_file_shard003.csv",
]

You can insert `interleave()` to interleave multiple file load operations.

In [34]:
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.interleave(IOBoundDataset)

benchmark(ds)

Execution time: 6.420450509001967


Let's take a look at how they are loaded.

In [35]:
for d in ds:
    print(d)

tf.Tensor([0], shape=(1,), dtype=int64)
tf.Tensor([0], shape=(1,), dtype=int64)
tf.Tensor([0], shape=(1,), dtype=int64)
tf.Tensor([1], shape=(1,), dtype=int64)
tf.Tensor([1], shape=(1,), dtype=int64)
tf.Tensor([1], shape=(1,), dtype=int64)
tf.Tensor([2], shape=(1,), dtype=int64)
tf.Tensor([2], shape=(1,), dtype=int64)
tf.Tensor([2], shape=(1,), dtype=int64)
tf.Tensor([3], shape=(1,), dtype=int64)
tf.Tensor([3], shape=(1,), dtype=int64)
tf.Tensor([3], shape=(1,), dtype=int64)
tf.Tensor([4], shape=(1,), dtype=int64)
tf.Tensor([4], shape=(1,), dtype=int64)
tf.Tensor([4], shape=(1,), dtype=int64)


Each dataset yields the values 0 through 4, and we can see that the elements loaded from the 3 files are interleaved.

The figure below shows the flow of sequential interleaving (with 2 files in this case):
![sequential interleave](https://www.tensorflow.org/guide/images/data_performance/sequential_interleave.svg)
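The order printed above follows from how `interleave()` cycles through its inputs: it keeps up to `cycle_length` input datasets open (here, all three files) and takes `block_length` elements from each in turn, opening the next input whenever one runs out. The plain-Python model below is a simplified, sequential sketch of that round-robin; it ignores parallelism and is not TensorFlow's code. The shard names are hypothetical.

```python
def interleave(inputs, open_fn, cycle_length, block_length=1):
    """Simplified, sequential model of Dataset.interleave."""
    inputs = iter(inputs)
    active = []
    for _ in range(cycle_length):  # open the first cycle_length inputs
        try:
            active.append(iter(open_fn(next(inputs))))
        except StopIteration:
            break
    slot = 0
    while active:
        exhausted = False
        for _ in range(block_length):
            try:
                yield next(active[slot])
            except StopIteration:
                exhausted = True
                break
        if exhausted:
            try:
                active[slot] = iter(open_fn(next(inputs)))  # reuse the slot
            except StopIteration:
                active.pop(slot)  # nothing left to open: shrink the cycle
                if active:
                    slot %= len(active)
                continue
        slot = (slot + 1) % len(active)


shard_names = ["shard001", "shard002", "shard003"]
print(list(interleave(shard_names, lambda f: range(5), cycle_length=3)))
# [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
print(list(interleave(shard_names, lambda f: range(5), cycle_length=2)))
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4]
```

With `cycle_length=2`, the third shard is only opened once the first one is exhausted, which is why its elements come last.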

Like `.map()`, you can parallelize this interleave operation by adding `num_parallel_calls` for further performance gain.

In [36]:
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.interleave(IOBoundDataset, num_parallel_calls=len(files))

benchmark(ds)

Execution time: 2.24809078394901


![parallel interleave](https://www.tensorflow.org/guide/images/data_performance/parallel_interleave.svg)

By combining these techniques, you can design a highly optimized data loading and transformation pipeline.
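As one possible illustration, here is a sketch of a pipeline over several sharded CSV files that combines the techniques above. The shards are listed with `tf.data.Dataset.list_files` (`TextLineDataset` itself does not expand glob patterns), read in parallel with `interleave()` while skipping each file's header, batched before a vectorized parsing `map()`, and cached. The two tiny shards and their `label,f1,f2` schema are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical data: two small CSV shards, each with a header row
# followed by "label,f1,f2" rows.
tmp_dir = tempfile.mkdtemp()
for shard in range(2):
    with open(os.path.join(tmp_dir, f"shard{shard}.csv"), "w") as f:
        f.write("label,f1,f2\n")
        for i in range(4):
            f.write(f"{shard * 10 + i},{i},{i * 2}\n")


def parse_batch(rows):
    # Vectorized parsing: split a whole batch of lines at once.
    fields = tf.strings.split(rows, ",").to_tensor()
    labels = tf.strings.to_number(fields[:, 0])
    features = tf.strings.to_number(fields[:, 1:])
    return features, labels


def make_pipeline(pattern, batch_size=2):
    files = tf.data.Dataset.list_files(pattern, shuffle=False)
    # Read the shards in parallel, skipping each file's own header.
    ds = files.interleave(
        lambda path: tf.data.TextLineDataset(path).skip(1),
        cycle_length=2,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.map(parse_batch, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache the small, already-parsed batches after the first full pass.
    return ds.cache()


pipeline = make_pipeline(os.path.join(tmp_dir, "*.csv"))
for epoch in range(2):
    for features, labels in pipeline:
        print(epoch, features.shape, labels.numpy())
```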

Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.