# Benchmarking Hub and WebDataset
This notebook describes the method and results of comparing two dataset management packages, Hub and WebDataset, in terms of dataset access and iteration time. The experiment expands on a [blog post released by PyTorch](https://pytorch.org/blog/efficient-pytorch-io-library-for-large-datasets-many-files-many-gpus/). We reuse the sample code from that post and juxtapose it with an equivalent implementation using Hub.

As in the original post, we use the ImageNet dataset. For WebDataset, the data is sharded as described in the WebDataset documentation; for Hub, the dataset was converted into a Hub-compliant format.

## Method

We use AWS to run the benchmarks.

Specification of the machine used for benchmarking:

<table>
  <tr>
    <th>Machine</th>
    <td>AWS EC2 r5n.metal instance</td>
  </tr>
    <tr>
    <th>Memory</th>
    <td>768 GB</td>
  </tr>
    <tr>
    <th>CPU</th>
    <td>Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz</td>
  </tr>
    <tr>
    <th>#vCPU</th>
    <td>96</td>
  </tr>
</table>

The data is stored locally on the instance.

Variable parameters include:
* number of workers (24, 16, 8 and 4)
* batch size (1000, mimicking the original post; the Hub runs are repeated with 96)

## Installing the Dependencies

First, we install the dependencies listed in the [tmbdev/pytorch-imagenet-wds](https://github.com/tmbdev/pytorch-imagenet-wds) repository to set up the environment. The hub, torch and webdataset versions are pinned for reproducibility.

In [None]:
!pip install hub==1.2.3
!pip install torch==1.8.1
!pip install webdataset==0.1.54
!pip install torchvision
!pip install braceexpand
!pip install numpy
!pip install scipy
!pip install tk
!pip install matplotlib
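
Because the comparison depends on the pinned versions, it can help to confirm them before running anything. The following sketch is our addition, not part of the original setup: it uses the standard-library `importlib.metadata` module (Python 3.8+), and `check_versions` and `PINNED` are hypothetical names. Local version labels such as `+cu102`, which torch wheels often carry, are ignored.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the install cell; the other packages are left unpinned.
PINNED = {"hub": "1.2.3", "torch": "1.8.1", "webdataset": "0.1.54"}


def check_versions(pinned):
    """Return {package: (expected, found)} for every mismatch.

    `found` is None when the package is not installed. Anything after a
    "+" (a local version label, e.g. "1.8.1+cu102") is ignored.
    """
    mismatches = {}
    for package, expected in pinned.items():
        try:
            found = version(package).split("+")[0]
        except PackageNotFoundError:
            found = None
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


print(check_versions(PINNED) or "All pinned versions are installed.")
```

An empty result means the environment matches the pinned versions.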

Now that the dependencies have been installed, we can focus on sharding the dataset for WebDataset. Following the instructions in the repository mentioned above, we adapted its script for sharding the ImageNet dataset.

A few parameters (`maxsize`, `maxcount`, `filekey`, and the `data` and `shards` directories) can be adjusted to match your environment so that the sharding functions run without errors.

In [None]:
import os
import random
import sys
import time

import hub  # needed for hub.Dataset below
import numpy as np
import torch
import torchvision
import webdataset as wds
from torchvision import datasets

## Preparing the WebDataset Shards

The following cell uses code from the WebDataset tutorial. It only needs to be run once to shard the data.

In [None]:
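# Sharding parameters (adjust to your environment).
maxsize = 1e9        # maximum shard size in bytes
maxcount = 1000      # maximum number of samples per shard
filekey = False      # if True, use the original file name as the sample key
data = "./data"      # ImageNet root, as expected by torchvision.datasets.ImageNet
shards = "./shards"  # destination directory for the .tar shards

if not os.path.isdir(os.path.join(data, "train")):
    print(f"{data}: should be directory containing ImageNet", file=sys.stderr)
    print("suitable as argument for torchvision.datasets.ImageNet(...)", file=sys.stderr)
    sys.exit(1)

# Create the shard directory if needed, then make sure it is writable.
os.makedirs(shards, exist_ok=True)
if not os.access(shards, os.W_OK):
    print(f"{shards}: should be a writable destination directory for shards", file=sys.stderr)
    sys.exit(1)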


def readfile(fname):
    "Read a binary file from disk."
    with open(fname, "rb") as stream:
        return stream.read()


all_keys = set()


def write_dataset(imagenet, base="./shards", split="train"):

    # We're using the torchvision ImageNet dataset
    # to parse the metadata; however, we will read
    # the compressed images directly from disk (to
    # avoid having to reencode them)
    ds = datasets.ImageNet(imagenet, split=split)
    nimages = len(ds.imgs)
    print("# nimages", nimages)

    # We shuffle the indexes to make sure that we
    # don't get any large sequences of a single class
    # in the dataset.
    indexes = list(range(nimages))
    random.shuffle(indexes)

    # This is the output pattern under which we write shards.
    pattern = os.path.join(base, f"imagenet-{split}-%06d.tar")

    with wds.ShardWriter(pattern, maxsize=int(maxsize), maxcount=int(maxcount)) as sink:
        for i in indexes:

            # Internal information from the ImageNet dataset
            # instance: the file name and the numerical class.
            fname, cls = ds.imgs[i]
            assert cls == ds.targets[i]

            # Read the JPEG-compressed image file contents.
            image = readfile(fname)

            # Construct a unique key from the filename.
            key = os.path.splitext(os.path.basename(fname))[0]

            # Useful check.
            assert key not in all_keys
            all_keys.add(key)

            # Construct a sample.
            xkey = key if filekey else "%07d" % i
            sample = {"__key__": xkey, "jpg": image, "cls": cls}

            # Write the sample to the sharded tar archives.
            sink.write(sample)


# The source data must be in a torchvision.datasets.ImageNet-compatible layout.
write_dataset(data, base="./shards")
webdataset_url = "./shards/imagenet-train-{000000..001281}.tar"
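
The brace range in `webdataset_url` follows from the shard limits. As a rough sanity check, the sketch below computes the expected number of shards and mimics, in simplified form, how samples end up in the tar files: each sample becomes consecutive tar members named `<key>.<extension>`. `shard_count` and `write_toy_shards` are hypothetical helpers; the toy writer rotates only on `maxcount` (ignoring `maxsize`), writes non-bytes values as text, and is not the `wds.ShardWriter` implementation.

```python
import math
import os
import tarfile
import tempfile
from io import BytesIO


def shard_count(num_samples, maxcount):
    """Number of shards when every shard holds at most `maxcount` samples."""
    return math.ceil(num_samples / maxcount)


def write_toy_shards(samples, base, maxcount):
    """Simplified stand-in for a shard writer that rotates by sample count only.

    Each sample is a dict with "__key__" plus one entry per file extension;
    every entry is stored as a tar member named "<key>.<extension>".
    """
    paths = []
    for start in range(0, len(samples), maxcount):
        path = os.path.join(base, "toy-%06d.tar" % len(paths))
        with tarfile.open(path, "w") as tar:
            for sample in samples[start:start + maxcount]:
                for ext, payload in sample.items():
                    if ext == "__key__":
                        continue
                    # Bytes are stored as-is; anything else is written as text.
                    data = payload if isinstance(payload, bytes) else str(payload).encode()
                    info = tarfile.TarInfo(f"{sample['__key__']}.{ext}")
                    info.size = len(data)
                    tar.addfile(info, BytesIO(data))
        paths.append(path)
    return paths


# ImageNet's training split has 1,281,167 images; with maxcount=1000 this
# gives 1282 shards, numbered 000000 through 001281.
print(shard_count(1_281_167, 1000))  # 1282

with tempfile.TemporaryDirectory() as tmp:
    toy = [{"__key__": "%07d" % i, "jpg": b"\xff\xd8", "cls": i % 3} for i in range(5)]
    paths = write_toy_shards(toy, tmp, maxcount=2)
    print([os.path.basename(p) for p in paths])  # ['toy-000000.tar', 'toy-000001.tar', 'toy-000002.tar']
    with tarfile.open(paths[0]) as tar:
        print(tar.getnames())  # ['0000000.jpg', '0000000.cls', '0000001.jpg', '0000001.cls']
```

Because 1000 ImageNet JPEGs are far smaller than `maxsize = 1e9` bytes, `maxcount` is the limit that determines the shard count here.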

## Preparing the Hub Dataset

The Hub-format dataset only needs to be copied from the S3 bucket to the local instance. Alternatively, it can be streamed directly from S3.

In [None]:
hub_url = "./imagenet-hub"
s3_url = "s3://internal-datasets/imagenet-classification/imagenet2012/"
dataset = hub.Dataset(s3_url)
dataset.copy(hub_url)

## Timing the Read Access of WebDataset

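Now that the dataset has been sharded for WebDataset, we can build the data loaders that iterate over the dataset and time the read-access overhead.

We define the parameters with which we test both functions. The `employ` decorator runs a timing function once for each worker count in `WORKERS` and returns the elapsed times in seconds, rounded to three decimals, in the same order.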

In [None]:
WORKERS = [24, 16, 8, 4]
batch_size = 1000


def employ(workers):
    def decorator(f):
        def wrapper(*args):
            times = []
            for n in workers:
                times.append(f(*args, n))
            return np.round(times, 3)
        return wrapper
    return decorator
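
To see how `employ` passes arguments, the sketch below applies the same pattern to a hypothetical `fake_benchmark` function that returns a synthetic number instead of timing a loader. It uses a plain list and `round` instead of NumPy, and adds `functools.wraps` so the wrapped function keeps its name; both are our choices, not part of the original.

```python
from functools import wraps

WORKERS = [24, 16, 8, 4]


def employ(workers):
    """Call the decorated function once per worker count (passed as the last argument)."""
    def decorator(f):
        @wraps(f)
        def wrapper(*args):
            return [round(f(*args, n), 3) for n in workers]
        return wrapper
    return decorator


@employ(WORKERS)
def fake_benchmark(url, batch_size, num_workers=1):
    # Hypothetical stand-in for time_webdataset/time_hub: no loader is timed,
    # it just returns a value that shrinks as the worker count grows.
    return batch_size / num_workers


print(fake_benchmark("unused-url", 96))  # [4.0, 6.0, 12.0, 24.0]
```

The first entry of each result therefore corresponds to 24 workers and the last to 4 workers.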

In [None]:
@employ(WORKERS)
def time_webdataset(url, batch_size, num_workers=1):
    # Read samples from the tar shards (no explicit decoding or shuffling).
    dataset = wds.Dataset(url)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers
    )
    start = time.time()
    # Iterate over every batch; only the read/iteration time is measured.
    for batch in loader:
        pass
    end = time.time()
    return end - start
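
Because `start` is taken before the loop, the measured time also includes starting the `DataLoader` worker processes, which happens when iteration begins. To separate that start-up cost from steady-state iteration, a sketch like the following could be used. `time_iteration` is a hypothetical helper; it uses `time.perf_counter`, which is intended for measuring intervals.

```python
import time


def time_iteration(iterable):
    """Iterate over `iterable` once and report timing.

    Returns (time_to_first_item, total_time, num_items), times in seconds.
    For a DataLoader, the first interval includes worker start-up and the
    first batch. For an empty iterable, time_to_first_item equals total_time.
    """
    start = time.perf_counter()
    first = None
    count = 0
    for _ in iterable:
        count += 1
        if first is None:
            first = time.perf_counter() - start
    total = time.perf_counter() - start
    return (first if first is not None else total), total, count


first, total, n = time_iteration(range(1000))
print(f"{n} items, first after {first:.6f}s, total {total:.6f}s")
```

In `time_webdataset` or `time_hub`, the loop between `start` and `end` could be replaced by `time_iteration(loader)`.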

## Timing the Read Access of Hub Converted to PyTorch

Since WebDataset is built on PyTorch and Hub offers a PyTorch integration, it is useful to compare against Hub's dataset converted to a PyTorch dataset with `to_pytorch()`.

In [None]:
@employ(WORKERS)
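def time_hub(url, batch_size, num_workers=1):
    # Open the Hub dataset (a local path or an s3:// URL) and wrap it
    # as a PyTorch dataset.
    dataset = hub.Dataset(url)
    dataset = dataset.to_pytorch()
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Identity collate: each batch is returned as a plain list of samples.
        collate_fn=lambda b: b
    )
    start = time.time()
    # Iterate over every batch; only the read/iteration time is measured.
    for batch in loader:
        pass
    end = time.time()
    return end - start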

## Running the Experiment

With the utility functions defined, we can run the benchmarks. Each call returns the elapsed time in seconds for 24, 16, 8 and 4 workers, in that order.

In [None]:
time_webdataset(webdataset_url, batch_size)

array([232.544, 252.853, 235.366, 198.742])

In [None]:
time_hub(hub_url, batch_size)

array([408.312, 375.634, 417.064, 477.035])

To improve Hub's performance, we rerun the local Hub benchmark with a smaller batch size.

In [None]:
batch_size = 96

In [None]:
time_hub(hub_url, batch_size)

array([270.917, 254.519, 251.943, 289.542])

We can also test Hub when streaming the data remotely from S3, with the same batch size of 96.

In [None]:
time_hub(s3_url, batch_size)

array([1688.301, 2683.032, 4825.543, 7982.483])

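For local reads, WebDataset is 1.007 to 2.400x faster than Hub, depending on the configuration (WebDataset at batch size 1000 compared with the local Hub runs at batch sizes 1000 and 96). With a batch size of 96, Hub comes within about 1% to 46% of WebDataset, so in the best configurations their performance is close, with a modest advantage for WebDataset. Streaming Hub directly from S3 is considerably slower (about 1,688 to 7,982 s). However, given how much preprocessing time Hub saves (no sharding step is needed), we consider it the better choice for most dataset users.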
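
The speed-up range can be recomputed directly from the timings reported above. The sketch below divides each Hub time by the WebDataset time for the same worker count; `slowdowns` is a hypothetical helper, and the numbers are copied from the outputs in this notebook.

```python
# Timings (seconds) reported above, ordered by WORKERS = [24, 16, 8, 4].
wds_1000 = [232.544, 252.853, 235.366, 198.742]
hub_1000 = [408.312, 375.634, 417.064, 477.035]
hub_96 = [270.917, 254.519, 251.943, 289.542]
hub_s3_96 = [1688.301, 2683.032, 4825.543, 7982.483]


def slowdowns(slow_times, reference_times):
    """Per-worker-count ratio slow / reference (> 1 means the first is slower)."""
    return [round(s / r, 3) for s, r in zip(slow_times, reference_times)]


# Local Hub runs (both batch sizes) against WebDataset at batch size 1000.
local = slowdowns(hub_1000, wds_1000) + slowdowns(hub_96, wds_1000)
print(min(local), max(local))  # 1.007 2.4

# S3 streaming against the local Hub copy, both at batch size 96.
print(slowdowns(hub_s3_96, hub_96))
```

The second line shows that, at batch size 96, streaming from S3 took roughly 6 to 28 times as long as reading the local copy, with the gap widening as the number of workers decreases.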