## References

* https://github.com/webdataset/webdataset-tensorflow/blob/main/resnet-multi.py
* https://github.com/LAION-AI/LAION-SAFETY/blob/main/laionsafety.py

## Machine used

GCP n1-highmem (16 vCPUs, 104 GB RAM) with Debian 10

In [1]:
import tensorflow as tf
import numpy as np
import pprint
import webdataset as wds
from webdataset import multi
import typer

import os

In [3]:
# Shard pattern for the remote tar files. The `{00000..00046}` range is kept
# verbatim in the string and handed to WebDataset as part of the URL.
# The pattern is wrapped in a `pipe:` command so each shard is streamed through
# curl (`-L` follows redirects, `-s` silences progress output). `|| true` makes
# the shell command report success even when curl fails -- presumably so that a
# broken download does not abort the whole pipeline.
url = "pipe:curl -L -s {} || true".format(
    "http://3080.rom1504.fr/cah/laion400m_porn_data/{00000..00046}.tar"
)

In [4]:
def filter_dataset(item):
    """Return True only when the sample carries both a caption and an image.

    Needed for e.g. C@H, which (rarely) has no caption available.

    Args:
        item: A sample container that supports ``in`` (here a dict keyed by
            file extension).

    Returns:
        bool: True if both the ``"txt"`` and ``"jpg"`` keys are present.
    """
    required_keys = ("txt", "jpg")
    return all(key in item for key in required_keys)

In [5]:
class ImagenetData:
    """Bundle the WebDataset pipeline, its loader and the tf.data metadata.

    The same pieces could equally be written as module-level functions; the
    class only groups them for convenience.
    """

    def __init__(self, url=url):
        """Build the sample pipeline and the multi-process loader for `url`."""
        self.url = url

        # Read the shards (shuffled at shard level), skipping samples that
        # raise errors instead of aborting the stream.
        pipeline = wds.WebDataset(
            self.url, shardshuffle=True, handler=wds.ignore_and_continue
        )
        # Keep only samples that have both an image and a caption.
        pipeline = pipeline.select(filter_dataset)
        # Decode images with WebDataset's "rgb" handler.
        pipeline = pipeline.decode("rgb")
        # Emit (image, caption) tuples.
        self.dataset = pipeline.to_tuple("jpg", "txt")

        # One loader worker per CPU reported by the OS.
        self.loader = multi.MultiLoader(self.dataset, workers=os.cpu_count())

    def __iter__(self):
        """Yield `(float32 image array, caption as a NumPy str array)` pairs."""
        for image, caption in self.loader:
            image_f32 = image.astype("float32")
            caption_arr = np.array(caption).astype(str)
            yield image_f32, caption_arr

    def output_shapes(self):
        """Return the per-sample shapes declared to tf.data: image, caption.

        NOTE(review): assumes every decoded image is 256x256x3 -- confirm
        against the dataset.
        """
        image_shape = (256, 256, 3)
        caption_shape = ()  # scalar string
        return (image_shape, caption_shape)

    def output_types(self):
        """Return the per-sample dtypes declared to tf.data: image, caption."""
        image_dtype = tf.float32
        caption_dtype = tf.string
        return (image_dtype, caption_dtype)

In [6]:
def resize_images(image_batch, caption_batch):
    """Resize a batch of images to 260x260 and pass the captions through.

    Args:
        image_batch: Image tensor accepted by ``tf.image.resize``.
        caption_batch: Captions belonging to `image_batch`; returned unchanged.

    Returns:
        A ``(resized_images, caption_batch)`` tuple.
    """
    target_size = (260, 260)
    resized = tf.image.resize(image_batch, target_size, antialias=True)
    return resized, caption_batch

In [7]:
# tf.data options: shard by data (not by file) under a distribution strategy
# and enable the listed graph optimizations.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.apply_default_optimizations = True
options.experimental_optimization.filter_fusion = True

df = ImagenetData()
# `output_types` / `output_shapes` are deprecated arguments of
# `from_generator`; describe each element with an `output_signature` of
# TensorSpecs instead, built from the same shapes and dtypes the data class
# declares.
tdf = tf.data.Dataset.from_generator(
    generator=df.__iter__,
    output_signature=tuple(
        tf.TensorSpec(shape=shape, dtype=dtype)
        for shape, dtype in zip(df.output_shapes(), df.output_types())
    ),
)
tdf = tdf.with_options(options)
# Batch first so that the resize runs once per 512-image batch rather than once
# per image.
tdf = tdf.batch(512).map(resize_images, num_parallel_calls=tf.data.AUTOTUNE)
# Overlap producing the next batches with consuming the current one.
tdf = tdf.prefetch(tf.data.AUTOTUNE)

2022-01-08 06:14:41.848510: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [8]:
# Pull a single batch and print the image and caption batch shapes, one per
# line, as a sanity check of the pipeline.
for image_batch, caption_batch in tdf.take(1):
    print(image_batch.shape, caption_batch.shape, sep="\n")
    break

(512, 260, 260, 3)
(512,)


In [9]:
%%time

# Benchmarking for 10 batches: iterate the pipeline and discard the data, so
# the cell's timing reflects input-pipeline throughput only.
for i, (batch, label) in enumerate(tdf.take(10)):
    # Progress marker. With only 10 batches, `i % 40 == 0` holds solely for
    # i == 0, so exactly one dot is printed.
    if i % 40 == 0:
        print(".", end="")
print()  # end the line of progress dots

killing <Process(Process-1, started)>
killing <Process(Process-2, started)>
killing <Process(Process-3, started)>
killing <Process(Process-4, started)>
killing <Process(Process-5, started)>
killing <Process(Process-6, started)>
killing <Process(Process-7, started)>
killing <Process(Process-8, started)>
killing <Process(Process-9, started)>
killing <Process(Process-10, started)>
killing <Process(Process-11, started)>
killing <Process(Process-12, started)>
killing <Process(Process-13, started)>
killing <Process(Process-14, started)>
killing <Process(Process-15, started)>
killing <Process(Process-16, started)>
closing <zmq.Socket(zmq.PULL) at 0x7fe30c57d910>
.
CPU times: user 25.1 s, sys: 10.7 s, total: 35.8 s
Wall time: 13.6 s
