# Batch Inference on data in DeltaLake

In this tutorial, we showcase how to perform ML model batch inference on data in a DeltaLake table.

> **ML Model Batch Inference**
> 
> When we have a trained machine learning model, the next step is often to apply this model to a large amount of data. This involves efficiently loading the model into memory (potentially GPU memory) and then running data through the model to produce outputs.

To run this tutorial you will require AWS credentials to be correctly provisioned on your machine as all data is hosted in a requestor-pays bucket in AWS S3.

Let's get started!

## Provisioning Cloud Credentials

First, let's provision credentials to Daft! We can do so using the ``boto3`` library, and creating a Daft {class}`IOConfig <daft.io.IOConfig>` object like so:

In [2]:
import boto3
import daft

session = boto3.session.Session()
creds = session.get_credentials()
io_config = daft.io.IOConfig(
    s3=daft.io.S3Config(
        access_key=creds.secret_key,
        key_id=creds.access_key,
        session_token=creds.token,
        region_name="us-west-2",
    )
)

## Retrieving Data

Now we're ready to read data from our DeltaLake table!

We've hosted a 10k row sample of the validation set of imagenet for you to try this out.

Simply pass in the ``IOConfig`` that we previously created to the call in order to ensure that we can access the data.

In [3]:
df = daft.read_delta_lake("s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/", io_config=io_config)
df

folder Utf8,filename Utf8,source Struct[database: Utf8],"size Struct[depth: Utf8, height: Utf8, width: Utf8]",segmented Utf8,"object List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]"


For this demo, we're running this on our local machine and thus will be limiting the total amount of data to 100.

In [3]:
df = df.limit(100)
df = df.select("folder", "filename", "object")
df.collect()

ScanWithTask-LocalLimit-LocalLimit-Project [Stage:1]:   0%|          | 0/1 [00:00<?, ?it/s]

folder Utf8,filename Utf8,"object List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]"
val,ILSVRC2012_val_00000001,"[{bndbox: {xmax: 441, xmin: 111, ymax: 193, ymin: 108, }, difficult: 0, name: n01751748, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000002,"[{bndbox: {xmax: 499, xmin: 45, ymax: 162, ymin: 49, }, difficult: 0, name: n09193705, pose: Unspecified, truncated: 0, }, {bndbox: {xmax: 437, xmin: 2, ymax: 207, ymin: 69, }, difficult: 0, name: n09193705, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000003,"[{bndbox: {xmax: 385, xmin: 38, ymax: 373, ymin: 19, }, difficult: 0, name: n02105855, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000004,"[{bndbox: {xmax: 441, xmin: 94, ymax: 284, ymin: 15, }, difficult: 0, name: n04263257, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000005,"[{bndbox: {xmax: 425, xmin: 17, ymax: 332, ymin: 1, }, difficult: 0, name: n03125729, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000006,"[{bndbox: {xmax: 358, xmin: 105, ymax: 279, ymin: 204, }, difficult: 0, name: n01735189, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000007,"[{bndbox: {xmax: 498, xmin: 89, ymax: 268, ymin: 75, }, difficult: 0, name: n02346627, pose: Unspecified, truncated: 0, }]"
val,ILSVRC2012_val_00000008,"[{bndbox: {xmax: 181, xmin: 14, ymax: 328, ymin: 163, }, difficult: 0, name: n02776631, pose: Unspecified, truncated: 0, }, {bndbox: {xmax: 331, xmin: 176, ymax: 223, ymin: 81, }, difficult: 0, name: n02776631, pose: Unspecified, truncated: 0, }, {bndbox: {xmax: 236, xmin: 77, ymax: 155, ymin: 2, }, difficult: 0, name: n02776631, pose: Unspecified, truncated: 0, }, {bndbox: {xmax: 355, xmin: 163, ymax: 374, ymin: 219, }, difficult: 0, name: n02776631, pose: Unspecified, truncated: 0, }]"


## Retrieving Images

Let's now resolve the images to their URLs, and start downloading/decoding them into images in our dataframe!

In [4]:
df = df.with_column(
    "image_url",
    "s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/" + df["filename"] + ".jpeg"
)
df = df.with_column("image", df["image_url"].url.download().image.decode())

We also want to do a little preprocessing on our images to get them all into the same size. We can do this with the {meth}`.image.resize <daft.expressions.expressions.ExpressionImageNamespace.resize>` method!

In [5]:
df = df.with_column("image_resized_small", df["image"].image.resize(32, 32))
df = df.with_column("image_resized_large", df["image"].image.resize(256, 256))
df.show(4)

folder Utf8,filename Utf8,"object List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]",image_url Utf8,image Image[MIXED],image_resized_small Image[MIXED],image_resized_large Image[MIXED]
val,ILSVRC2012_val_00000001,"[{bndbox: {xmax: 441, xmin: 111, ymax: 193, ymin: 108, }, difficult: 0, name: n01751748, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000001.jpeg,,,
val,ILSVRC2012_val_00000002,"[{bndbox: {xmax: 499, xmin: 45, ymax: 162, ymin: 49, }, difficult: 0, name: n09193705, pose: Unspecified, truncated: 0, }, {bndbox: {xmax: 437, xmin: 2, ymax: 207, ymin: 69, }, difficult: 0, name: n09193705, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000002.jpeg,,,
val,ILSVRC2012_val_00000003,"[{bndbox: {xmax: 385, xmin: 38, ymax: 373, ymin: 19, }, difficult: 0, name: n02105855, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000003.jpeg,,,
val,ILSVRC2012_val_00000004,"[{bndbox: {xmax: 441, xmin: 94, ymax: 284, ymin: 15, }, difficult: 0, name: n04263257, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000004.jpeg,,,


## Running Batch Inference

Great! We now have our images nicely preprocessed, and are ready to run batch inference on them.

Let's run a simple ResNet image classifier on each image's "high-resolution" and "low-resolution" variant, to see how sensitive our model is to the resolution of the image!

First off, we define a "Stateful UDF" that will initialize our model once in the ``__init__`` method, and then use the same model across multiple invocations on different partitions of data.

In [7]:
import daft
import numpy as np
import torch
from torchvision.models import resnet50, ResNet50_Weights

@daft.udf(return_dtype=daft.DataType.string())
class ClassifyImage:
    def __init__(self):
        weights = ResNet50_Weights.DEFAULT
        self.model = resnet50(weights=weights)
        self.model.eval()
        self.preprocess = weights.transforms()
        self.category_map = weights.meta["categories"]

    def __call__(self, images: daft.Series, shape: list[int, int, int]):
        if len(images) == 0:
            return []

        # Convert the Daft Series into a list of Numpy arrays
        data = images.cast(daft.DataType.tensor(daft.DataType.uint8(), tuple(shape))).to_pylist()

        # Convert the numpy arrays into a torch tensor
        images_array = torch.tensor(np.array(data)).permute((0, 3, 1, 2))

        # Run the model, and map results back to a human-readable string
        batch = self.preprocess(images_array)
        prediction = self.model(batch).softmax(0)
        class_ids = prediction.argmax(1)
        scores = prediction[:, class_ids]
        return [self.category_map[class_id] for class_id in class_ids]


To run our model on the dataframe, simply call the ``ClassifyImage`` function we defined earlier on the columns!

NOTE: If we wanted to ensure that our UDF will run with a GPU, we can specify:

```
df.with_column(..., resource_request=daft.ResourceRequest(num_gpus=1))
```

In [8]:
df = df.with_column("predictions_lowres", ClassifyImage(df["image_resized_small"], [32, 32, 3]))
df = df.with_column("predictions_highres", ClassifyImage(df["image_resized_large"], [256, 256, 3]))

In [9]:
df.show(4)

folder Utf8,filename Utf8,"object List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]",image_url Utf8,image Image[MIXED],image_resized_small Image[MIXED],image_resized_large Image[MIXED],predictions_lowres Utf8,predictions_highres Utf8
val,ILSVRC2012_val_00000001,"[{bndbox: {xmax: 441, xmin: 111, ymax: 193, ymin: 108, }, difficult: 0, name: n01751748, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000001.jpeg,,,,rock python,rock python
val,ILSVRC2012_val_00000003,"[{bndbox: {xmax: 385, xmin: 38, ymax: 373, ymin: 19, }, difficult: 0, name: n02105855, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000003.jpeg,,,,Shetland sheepdog,Shetland sheepdog
val,ILSVRC2012_val_00000004,"[{bndbox: {xmax: 441, xmin: 94, ymax: 284, ymin: 15, }, difficult: 0, name: n04263257, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000004.jpeg,,,,eggnog,soup bowl
val,ILSVRC2012_val_00000005,"[{bndbox: {xmax: 425, xmin: 17, ymax: 332, ymin: 1, }, difficult: 0, name: n03125729, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000005.jpeg,,,,packet,cradle


Pretty cool! looks like decreasing the resolution of the image too much does have a strong effect on the model's performance, as expected.

We can go ahead and show **just** the rows that have show this behavior. We will also need to filter for rows where the image does not have 3 channels because that will break our code.

Note that the following cell will now take a much longer time to run as we need to run the model on all the rows instead of just the first 4!

In [10]:
# Filter out images where the number of channels != 3
df = df.where(df["image"].apply(lambda img: img.shape[2] == 3, return_dtype=daft.DataType.bool()))

# Show only rows where the predictions on the low-res/high-res images don't match
df = df.where(df["predictions_lowres"] != df["predictions_highres"])

df.show(4)

folder Utf8,filename Utf8,"object List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]",image_url Utf8,image Image[MIXED],image_resized_small Image[MIXED],image_resized_large Image[MIXED],predictions_lowres Utf8,predictions_highres Utf8
val,ILSVRC2012_val_00000004,"[{bndbox: {xmax: 441, xmin: 94, ymax: 284, ymin: 15, }, difficult: 0, name: n04263257, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000004.jpeg,,,,eggnog,soup bowl
val,ILSVRC2012_val_00000005,"[{bndbox: {xmax: 425, xmin: 17, ymax: 332, ymin: 1, }, difficult: 0, name: n03125729, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000005.jpeg,,,,packet,cradle
val,ILSVRC2012_val_00000006,"[{bndbox: {xmax: 358, xmin: 105, ymax: 279, ymin: 204, }, difficult: 0, name: n01735189, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000006.jpeg,,,,golf ball,sidewinder
val,ILSVRC2012_val_00000007,"[{bndbox: {xmax: 498, xmin: 89, ymax: 268, ymin: 75, }, difficult: 0, name: n02346627, pose: Unspecified, truncated: 0, }]",s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000007.jpeg,,,,Madagascar cat,porcupine
