# Benchmarks: Hub v/s WebDataset
In this notebook we compare the read performance of two packages, Hub and WebDataset.

Both packages use a sharded storage backend. For this experiment we use the ImageNet dataset.

## Installing the Dependencies

First, we install the dependencies listed in the repository

[https://github.com/tmbdev/pytorch-imagenet-wds](https://github.com/tmbdev/pytorch-imagenet-wds)

These set up the environment for WebDataset. The first command also installs Hub.

```python
!pip install hub
!pip install webdataset
!pip install braceexpand
!pip install numpy
!pip install scipy
!pip install tk
!pip install matplotlib
!pip install torch
!pip install torchvision
```
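The APIs of both libraries have changed between releases. For example, Hub calls used below such as `.compute()` and `.to_pytorch()` may not exist under those names in every Hub version. For that reason, it can help to record the installed versions next to the timings.

The following is a minimal sketch using the standard library's `importlib.metadata`. `package_versions` is a hypothetical helper, not part of either library.

```python
from importlib import metadata


def package_versions(names):
    """Return {package name: installed version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in the current environment.
            versions[name] = None
    return versions


# Record the versions of the packages used in this benchmark.
print(package_versions(["hub", "webdataset", "torch", "torchvision"]))
```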

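Now that the dependencies are installed, we can shard the dataset for WebDataset. The following script is adapted from the instructions in the repository above and shards the ImageNet dataset.

The parameters at the top of the sharding cell can be adjusted to match your setup:

- `data`: the ImageNet directory.
- `shards`: the output directory, which must already exist.
- `splits`: which splits to write.
- `maxcount` and `maxsize`: the per-shard limits on samples and bytes.
- `filekey`: whether to use each image's file name, rather than its index, as the sample key.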
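`maxcount` and `maxsize` act together. Roughly, the writer starts a new shard whenever the current one already holds `maxcount` samples or `maxsize` bytes.

The sketch below approximates that rule in plain Python. `assign_shards` is hypothetical, and it counts payload bytes only, ignoring tar header overhead. Under this rule a shard can overshoot `maxsize` by one sample.

With `maxsize = 1e9` and `maxcount = 1000`, the count limit binds first unless samples average more than 1e9 / 1000 = 1 MB.

```python
def assign_shards(sample_sizes, maxcount, maxsize):
    """Return the shard index for each sample, in write order.

    Approximates a count/size rollover rule: before writing a sample, a new
    shard is started if the current one already holds maxcount samples or
    maxsize bytes. Tar header overhead is ignored.
    """
    shard, count, size = 0, 0, 0
    assignment = []
    for nbytes in sample_sizes:
        if count >= maxcount or size >= maxsize:
            shard, count, size = shard + 1, 0, 0
        assignment.append(shard)
        count += 1
        size += nbytes
    return assignment


# 2500 samples of 100 KB each with the notebook's limits: the count limit decides.
by_count = assign_shards([100_000] * 2500, maxcount=1000, maxsize=1e9)
print(by_count[0], by_count[999], by_count[1000], by_count[-1])

# A small maxsize makes the size limit decide instead.
print(assign_shards([100_000] * 6, maxcount=1000, maxsize=250_000))
```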

```python
import sys
import os
import random
import time

import torch
import torchvision
from torchvision import datasets
import webdataset as wds
import hub  # needed by the Hub timing functions below
```

```python
# Parameters (not attached to a command-line parser)
maxsize = 1e9         # maximum size of a shard, in bytes
maxcount = 1000       # maximum number of samples per shard
filekey = False       # if True, use the image file name as the sample key instead of the index
data = "./data"       # ImageNet directory usable with torchvision.datasets.ImageNet
shards = "./shards"   # existing, writable directory where shards are written
splits = "train,val"  # comma-separated list of splits to write

# Original command-line code from the repository, kept for reference:
# parser = argparse.ArgumentParser("""Generate sharded dataset from original ImageNet data.""")
# parser.add_argument("--splits", default="train,val", help="which splits to write")
# parser.add_argument(
#     "--filekey", action="store_true", help="use file as key (default: index)"
# )
# parser.add_argument("--maxsize", type=float, default=1e9)
# parser.add_argument("--maxcount", type=float, default=1000)
# parser.add_argument(
#     "--shards", default="./shards", help="directory where shards are written"
# )
# parser.add_argument(
#     "--data",
#     default="./data",
#     help="directory containing ImageNet data distribution suitable for torchvision.datasets",
# )
# args = parser.parse_args()
#
# assert args.maxsize > 10000000
# assert args.maxcount < 1000000

# Fail early with an exception (sys.exit is awkward inside a notebook).
if not os.path.isdir(os.path.join(data, "train")):
    raise FileNotFoundError(
        f"{data}: should be a directory containing ImageNet, "
        "suitable as argument for torchvision.datasets.ImageNet(...)"
    )

if not os.path.isdir(shards):
    raise FileNotFoundError(f"{shards}: should be a writable destination directory for shards")

# Keep `splits` a string so that the cell can be re-run.
split_names = splits.split(",")


def readfile(fname):
    "Read a binary file from disk."
    with open(fname, "rb") as stream:
        return stream.read()


# Keys seen so far, across all splits, used to check that keys are unique.
all_keys = set()


def write_dataset(imagenet, base="./shards", split="train"):

    # We use the torchvision ImageNet dataset only to parse the metadata;
    # the compressed images are read directly from disk (to avoid
    # having to re-encode them).
    ds = datasets.ImageNet(imagenet, split=split)
    nimages = len(ds.imgs)
    print("# nimages", nimages)

    # Shuffle the indexes so that the shards do not contain
    # long runs of a single class.
    indexes = list(range(nimages))
    random.shuffle(indexes)

    # Output pattern under which the shards are written.
    pattern = os.path.join(base, f"imagenet-{split}-%06d.tar")

    with wds.ShardWriter(pattern, maxsize=int(maxsize), maxcount=int(maxcount)) as sink:
        for i in indexes:

            # Internal information from the ImageNet dataset
            # instance: the file name and the numerical class.
            fname, cls = ds.imgs[i]
            assert cls == ds.targets[i]

            # Read the JPEG-compressed image file contents.
            image = readfile(fname)

            # Construct a unique key from the filename.
            key = os.path.splitext(os.path.basename(fname))[0]

            # Sanity check: file-name keys must be unique across all splits.
            assert key not in all_keys
            all_keys.add(key)

            # Construct a sample.
            xkey = key if filekey else "%07d" % i
            sample = {"__key__": xkey, "jpg": image, "cls": cls}

            # Write the sample to the sharded tar archives.
            sink.write(sample)


for split in split_names:
    print("# split", split)
    write_dataset(data, base=shards, split=split)
```
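Each shard written above is an ordinary tar archive. In that layout, every field of a sample is stored as a separate member named `<key>.<field>`, for example `0000042.jpg` and `0000042.cls`. WebDataset later reassembles samples from members that share a key.

The stdlib sketch below writes toy shards in that layout so they can be inspected without ImageNet. Two caveats:

- `write_tar_shards` is a hypothetical helper that rolls over on sample count only.
- Encoding non-bytes values with `str()` is an assumption standing in for the library's own encoders.

```python
import io
import os
import tarfile
import tempfile


def write_tar_shards(samples, pattern, maxcount):
    """Write dict samples into tar shards named pattern % shard_index.

    Each field of a sample (other than "__key__") becomes one tar member
    named "<key>.<field>". A new shard starts every maxcount samples.
    """
    paths, tar = [], None
    for i, sample in enumerate(samples):
        if i % maxcount == 0:
            if tar is not None:
                tar.close()
            paths.append(pattern % len(paths))
            tar = tarfile.open(paths[-1], "w")
        key = sample["__key__"]
        for field, value in sample.items():
            if field == "__key__":
                continue
            # Assumption: non-bytes values are stored as their UTF-8 text form.
            payload = value if isinstance(value, bytes) else str(value).encode("utf-8")
            info = tarfile.TarInfo(name=f"{key}.{field}")
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    if tar is not None:
        tar.close()
    return paths


with tempfile.TemporaryDirectory() as tmp:
    toy = [{"__key__": "%07d" % i, "jpg": b"fake-jpeg-bytes", "cls": i % 3} for i in range(5)]
    written = write_tar_shards(toy, os.path.join(tmp, "toy-train-%06d.tar"), maxcount=2)
    print([os.path.basename(p) for p in written])
    with tarfile.open(written[0]) as first:
        print(first.getnames())
```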

## Timing the Read Access of WebDataset

Now that the dataset has been sharded for WebDataset, we can build a data loader that iterates over the shards and time the read access.

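```python
def time_webdataset(url, batch_size=64):
    # wds.WebDataset is the dataset class in current webdataset releases (older
    # releases called it wds.Dataset). Samples arrive as dicts; to_tuple picks out
    # the key, the raw JPEG bytes and the class label. Images are not decoded,
    # so this measures read access only.
    dataset = wds.WebDataset(url).to_tuple("__key__", "jpg", "cls")
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=8)
    start = time.time()
    for keys, inputs, targets in loader:
        x, y = inputs, targets
    end = time.time()
    print("Time taken by WebDataset: ", end - start)
```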
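On the reading side, the loader above consumes the shards sequentially. It groups consecutive tar members that share a key into one sample.

Below is a simplified stdlib reader that mimics that grouping. `iter_samples` is hypothetical, not a WebDataset API. It keys on the base name before the first `.` and ignores directory prefixes, which the real library keeps as part of the key.

```python
import io
import os
import tarfile
import tempfile


def iter_samples(tar_path):
    """Yield one dict per sample, grouping consecutive tar members by key.

    The key is the member's base name up to the first "."; the rest of the
    name is used as the field name. Field values are the raw member bytes.
    """
    current = None
    with tarfile.open(tar_path) as tar:
        for member in tar:
            if not member.isfile():
                continue
            key, _, field = os.path.basename(member.name).partition(".")
            if current is not None and current["__key__"] != key:
                yield current
                current = None
            if current is None:
                current = {"__key__": key}
            current[field] = tar.extractfile(member).read()
    if current is not None:
        yield current


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy-000000.tar")
    with tarfile.open(path, "w") as tar:
        for name, payload in [("0000000.jpg", b"img0"), ("0000000.cls", b"1"),
                              ("0000001.jpg", b"img1"), ("0000001.cls", b"2")]:
            info = tarfile.TarInfo(name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    for sample in iter_samples(path):
        print(sample)
```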

## Timing the Read Access of Hub

The following cell times Hub on the same task, reading images and labels in batches of `batch_size` samples.

```python
def time_hub(tag, batch_size=64):
    dataset = hub.Dataset(tag)
    n = dataset.shape[0]
    start = time.time()
    # Read consecutive slices of batch_size samples; the last slice is shorter
    # when n is not a multiple of batch_size.
    for i in range(0, n, batch_size):
        stop = min(i + batch_size, n)
        x = dataset["image"][i:stop].compute()
        y = dataset["label"][i:stop].compute()
    end = time.time()
    print("Time taken by Hub (no conversion): ", end - start)
```
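`time_hub` reads the dataset in full batches of `batch_size` samples, plus one shorter final batch when needed. The small stdlib sketch below reproduces the same index arithmetic. `batch_slices` is a hypothetical helper.

```python
def batch_slices(n, batch_size):
    """Return slice objects covering range(n) in consecutive batches.

    The final slice is shorter when n is not a multiple of batch_size.
    """
    return [slice(start, min(start + batch_size, n)) for start in range(0, n, batch_size)]


print(batch_slices(10, 4))
```

Every index is visited exactly once, which is why a single loop over these slices is equivalent to a full-batch loop followed by a separate remainder read.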

## Timing the Read Access of Hub converted to PyTorch

WebDataset is built on PyTorch's dataset interface, and Hub offers a PyTorch integration. We therefore also time Hub after converting it to a PyTorch dataset.

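```python
def time_hub_torch(tag, batch_size=64):
    dataset = hub.Dataset(tag)
    # Wrap the Hub dataset as a PyTorch dataset and batch it with a DataLoader.
    # Note: num_workers is left at its default (0, i.e. loading in the main
    # process), whereas time_webdataset uses num_workers=8.
    loader = torch.utils.data.DataLoader(
        dataset.to_pytorch(),
        batch_size=batch_size,
    )
    start = time.time()
    for batch in loader:
        x = batch["image"]
        y = batch["label"]
    end = time.time()
    print("Time taken by Hub converted to PyTorch: ", end - start)
```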
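Each sample from the converted Hub dataset is indexed by field name (`batch["image"]`, `batch["label"]`). For dict samples, the DataLoader's default collation groups the batch field by field.

The sketch below shows that grouping in plain Python. `collate_dicts` and `batches` are hypothetical helpers. PyTorch's real `default_collate` additionally stacks tensors, arrays and numbers into batched tensors, which this sketch does not do. As with the default `drop_last=False`, the last partial batch is kept.

```python
def collate_dicts(samples):
    """Turn a list of per-sample dicts into one dict of per-field lists."""
    if not samples:
        return {}
    return {key: [sample[key] for sample in samples] for key in samples[0]}


def batches(samples, batch_size):
    """Yield collated batches; the last batch may be smaller."""
    for start in range(0, len(samples), batch_size):
        yield collate_dicts(samples[start:start + batch_size])


samples = [{"image": f"img{i}", "label": i % 2} for i in range(5)]
for batch in batches(samples, 2):
    print(batch["image"], batch["label"])
```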

## Running the Experiment

With the timing functions defined, we set the experiment parameters and run each benchmark.

```python
sharedurl = "./shards/imagenet-train-{000000..001281}.tar"  # WebDataset shards written by the sharding cell above
tag = "./data-hub"                                          # ImageNet stored in Hub format

BATCH_SIZE = 1000                                           # Batch size for all the data loaders; can be changed
```
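WebDataset expands the brace range in `sharedurl` into one path per shard; it relies on the `braceexpand` package installed above. The range follows from the sharding parameters.

ImageNet's training split has 1,281,167 images. At `maxcount = 1000`, the sharding cell produces ceil(1,281,167 / 1000) = 1282 shards, numbered `000000` through `001281`. This assumes `maxsize` never triggers first.

The sketch below checks that arithmetic. `expand_numeric_braces` is a hypothetical, minimal stand-in that handles only a single zero-padded numeric range; the real `braceexpand` supports much more syntax.

```python
import math
import re


def expand_numeric_braces(pattern):
    """Expand one "{start..end}" numeric range, preserving zero padding."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if match is None:
        return [pattern]
    first, last = match.group(1), match.group(2)
    width = len(first)
    prefix, suffix = pattern[:match.start()], pattern[match.end():]
    return [f"{prefix}{i:0{width}d}{suffix}" for i in range(int(first), int(last) + 1)]


urls = expand_numeric_braces("imagenet-train-{000000..001281}.tar")
print(len(urls), urls[0], urls[-1])

# 1,281,167 training images at 1000 samples per shard.
print(math.ceil(1_281_167 / 1000))
```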

```python
time_webdataset(sharedurl, BATCH_SIZE)
time_hub(tag, BATCH_SIZE)
time_hub_torch(tag, BATCH_SIZE)
```
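Each benchmark above is run once and timed with `time.time()`, so a single result can be noisy. One option, sketched here under our own assumptions, is to repeat each run and report the minimum and median using `time.perf_counter()`, which is intended for measuring short durations. `benchmark` and `summarize` are hypothetical helpers, not part of either library.

```python
import statistics
import time


def benchmark(fn, *args, repeats=3, warmup=1, **kwargs):
    """Call fn(*args, **kwargs) warmup + repeats times; return the timed durations.

    Warm-up calls are executed but not recorded.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        durations.append(time.perf_counter() - start)
    return durations


def summarize(name, durations):
    print(f"{name}: min {min(durations):.3f}s, "
          f"median {statistics.median(durations):.3f}s over {len(durations)} runs")


summarize("sum(range(10**6))", benchmark(sum, range(1_000_000), repeats=5))
```

For example, `benchmark(time_webdataset, sharedurl, BATCH_SIZE)` would run the WebDataset benchmark four times and record the last three.

Whether a warm-up run is wanted depends on the goal. After a warm-up, shard files may be served from the operating system's page cache, so the numbers reflect warm rather than cold reads.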