In [None]:
import concurrent.futures
import io
import threading
from typing import List, Tuple

import boto3
import numpy as np
import PIL.Image
import ray.data
import requests
import torch
import torchvision
from torch import hub

import daft
from daft.datarepo import get_client
from daft.datarepo.query import functions as F
from daft.experimental.dataclasses import dataclass
from daft.fields import DaftImageField

# Daft Benchmarking

This notebook benchmarks the performance of Daft on a subset of the OpenImages dataset (the validation split).

The OpenImages dataset is hosted publicly in AWS in the `s3://open-images-dataset` bucket. However, this bucket is slow to access because its bandwidth is shared with all other users of the data, so we first copy the data into our own AWS account/bucket using the AWS CLI:

```
aws s3 sync s3://open-images-dataset/validation s3://eventual-data-test-bucket/benchmarking/open-images-dataset/validation
```

## Ingesting Datarepo

Now that the data is in our own bucket, the first step in this notebook is to ingest it into a Daft Datarepo.

We do the following:

1. Get a list of all the images in our S3 bucket
2. Download these JPEG images into a Ray cluster, using a ThreadPoolExecutor within each batch to fetch many images concurrently
3. Decode the images, resize them to 256 x 256 and convert them to RGB
4. Write the images to a Datarepo

In [None]:
# Initialise Daft once, before any of the Datarepo/query calls made later in this notebook.
daft.init()

In [None]:
# S3 bucket holding our own copy of the OpenImages data (populated by the `aws s3 sync` command above).
BUCKET = "eventual-data-test-bucket"
# Key prefix, inside BUCKET, of the copied validation split.
DATA_PREFIX = "benchmarking/open-images-dataset/validation"

In [None]:
%%time

s3_paginator = boto3.client("s3").get_paginator('list_objects_v2')

objs = []
for page in s3_paginator.paginate(Bucket=BUCKET, Prefix=DATA_PREFIX):
    for content in page.get('Contents', ()):
        objs.append(content['Key'])
        
print(f"Number of objects: {len(objs)}")

In [None]:
@dataclass
class OpenImageRaw:
    """One ingested OpenImages sample stored in the Datarepo.

    ``key`` is the S3 object key the image was downloaded from; ``img`` is the
    decoded image (produced by ``resized_pil_image`` in the ingestion pipeline).
    """

    key: str
    img: PIL.Image.Image = DaftImageField()


def download_batch(batch: List[str]) -> List[Tuple[str, bytes]]:
    """Download a batch of S3 objects from ``BUCKET`` concurrently.

    Args:
        batch: S3 object keys to fetch from ``BUCKET``.

    Returns:
        A list of ``(key, contents)`` tuples, in the same order as ``batch``.
    """
    # A single thread-local store shared by every worker thread of the executor
    # below: each thread lazily builds its own boto3 client once and reuses it
    # for all the keys it handles. (Creating it inside ``download_single`` would
    # yield a fresh, empty thread-local on every call, defeating the cache.)
    # boto3 recommends one Session per thread rather than sharing one.
    local = threading.local()

    def download_single(key: str) -> Tuple[str, bytes]:
        if not hasattr(local, "s3"):
            local.s3 = boto3.session.Session().client('s3')
        response = local.s3.get_object(Bucket=BUCKET, Key=key)
        body = response["Body"]
        try:
            contents = body.read()
        finally:
            # Release the underlying HTTP connection even if the read fails.
            body.close()
        return (key, contents)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # executor.map preserves input order.
        return list(executor.map(download_single, batch))


def resized_pil_image(payload: bytes, size: int = 256) -> PIL.Image.Image:
    """Decode ``payload`` as an image and return it as a ``size`` x ``size`` RGB image.

    Decoding is best-effort: if the bytes cannot be decoded, converted or resized,
    a blank (all-black) RGB image of the requested size is returned instead of
    raising, and a warning (with traceback) is logged so bad inputs stay visible.

    Args:
        payload: Raw encoded image bytes (e.g. the contents of a JPEG file).
        size: Side length, in pixels, of the square output image.

    Returns:
        An RGB ``PIL.Image.Image`` of size ``(size, size)``, fully loaded and
        independent of the temporary in-memory buffer.
    """
    import logging

    with io.BytesIO(payload) as f:
        try:
            # Convert to RGB *before* resizing: Pillow always uses nearest-neighbour
            # resampling for "1" and "P" mode images, so resizing first would degrade
            # palette images. convert() also forces the lazy decode while `f` is open.
            return PIL.Image.open(f).convert("RGB").resize((size, size))
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not decode image payload; substituting a blank %dx%d image",
                size,
                size,
                exc_info=True,
            )
            return PIL.Image.new("RGB", (size, size))

In [None]:
%%time

ds = ray.data.from_items(objs)
ds = ds.map_batches(
    download_batch
).map(
    lambda tup: OpenImageRaw(key=tup[0], img=resized_pil_image(tup[1]))
)

In [None]:
# Datarepo client used by the remaining cells of this notebook.
client = get_client()
# Create the Datarepo with OpenImageRaw as its row type; exists_ok=True presumably
# lets this cell be re-run when the repo already exists — confirm against daft docs.
client.create("open-images-validation-resized", dtype=OpenImageRaw, exists_ok=True)

In [None]:
# Open the Datarepo created above by its id.
datarepo = client.from_id("open-images-validation-resized")

In [None]:
# Write the ingested Ray dataset into the Datarepo, presumably replacing any existing
# contents ("overwrite") with partitions of up to 1024 rows — confirm against daft docs.
written_files = datarepo.overwrite(ds, rows_per_partition=1024)

## Running queries/processing

In [None]:
@F.batch_func(batch_size=8)
class BatchInferModel:
    """Batch UDF that embeds images with a pretrained ResNet-18.

    Each call receives a list of PIL images (``batch_size=8`` in the decorator
    presumably sets how many per call — confirm against daft docs) and returns
    one embedding per image, taken from the model's ``avgpool`` layer.
    """

    def __init__(self):
        """Load the pretrained model and build the feature extractor and input transform."""
        # Imported explicitly: `import torchvision` alone is not guaranteed to
        # load the `torchvision.models.feature_extraction` submodule.
        from torchvision.models.feature_extraction import create_feature_extractor

        # torchvision downloads pretrained weights into the torch hub directory.
        hub.set_dir("/tmp/.torchcache")
        self.model_name = "resnet18"
        model = torchvision.models.resnet18(pretrained=True).eval()
        # Expose the output of the `avgpool` layer under the key "embedding".
        self.feature_extractor = create_feature_extractor(
            model=model,
            return_nodes={'avgpool': 'embedding'},
        )
        # ImageNet mean/std normalisation used by the pretrained torchvision weights.
        self.to_tensor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def prepare_batch(self, image_data: List[PIL.Image.Image]) -> torch.Tensor:
        """Convert PIL images into one normalised tensor stacked along a new batch dim.

        ``torch.stack`` requires every image to yield a tensor of the same shape,
        i.e. same size and channel count.
        """
        return torch.stack([self.to_tensor(img) for img in image_data])

    def __call__(self, image_data: List[PIL.Image.Image]) -> List[np.ndarray]:
        """Embed each image with ResNet-18.

        Returns:
            One numpy array of shape ``(1, D)`` per input image, in input order.
            An empty batch yields an empty list.
        """
        if len(image_data) == 0:
            # torch.stack rejects an empty list; an empty batch has no embeddings.
            return []
        with torch.no_grad():
            batch = self.prepare_batch(image_data)
            features = self.feature_extractor(batch.float())['embedding']
            # Flatten everything after the batch dim (ResNet's avgpool emits (N, D, 1, 1)).
            np_embedding = torch.flatten(features, start_dim=1).cpu().numpy()
        # N equal sections -> N arrays of shape (1, D), one per image.
        return np.vsplit(np_embedding, len(image_data))

In [None]:
# Re-open the Datarepo by id for the query section (same call as in the ingestion section).
# NOTE(review): redundant when run top-to-bottom; presumably here so the handle reflects
# the data written by `overwrite` — confirm whether the earlier handle would suffice.
datarepo = client.from_id("open-images-validation-resized")

In [None]:
# Build (but do not yet run — see `query.execute()` below) a query over the Datarepo's
# OpenImageRaw rows that adds an "embeddings" column computed by BatchInferModel,
# presumably applied to each row's "img" field.
query = datarepo.query(OpenImageRaw).with_column("embeddings", BatchInferModel("img"))

In [None]:
%%time

# Run the query, including ResNet-18 embedding of every image; %%time reports its cost.
# NOTE(review): this rebinds `ds`, which earlier held the Ray ingestion dataset —
# consider a distinct name (e.g. `embeddings_ds`) to avoid confusing notebook state.
ds = query.execute()