In [0]:
# Load the synthetic raw-image dataset (Parquet files on S3) as a Spark DataFrame.
input_data = spark.read.parquet(
    "s3://air-example-data-2/10G-image-data-synthetic-raw-parquet"
)

In [0]:
import time
import torch
# NOTE(review): a notebook cell only displays its *last* bare expression, so
# this version lookup produces no visible output as written.
torch.__version__
# Displayed as the cell result: whether PyTorch can see a CUDA device from
# this notebook process.
torch.cuda.is_available()

True

In [0]:
# Report how much memory each Spark executor was configured with.
executor_memory = spark.conf.get("spark.executor.memory")
print("Executor memory: ", executor_memory)

Executor memory:  148728m


In [0]:
import pandas as pd
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights



In [0]:
# Enable Arrow-based columnar data transfer between Spark and pandas.
# The older key `spark.sql.execution.arrow.enabled` has been deprecated since
# Spark 3.0 in favor of the PySpark-specific key used here; this notebook
# already relies on Spark 3 features (type-hinted iterator pandas UDFs).
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

In [0]:
# Materialize the pretrained ResNet-50 weights on the driver and ship them to
# the executors as a broadcast variable (the counterpart of an AIR Checkpoint).
pretrained_resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
model_state = pretrained_resnet.state_dict()

# `sc` (the SparkContext) is pre-initialized by Databricks.
bc_model_state = sc.broadcast(model_state)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [0]:
# Preprocessing
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType
from typing import Iterator

import numpy as np

import torch
import time
import threading
import queue

# Per-channel statistics used to normalize the (already 0-1 scaled) images.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Tensor-level preprocessing pipeline: resize, center-crop to 224x224, then
# normalize each channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
preprocess_image = transforms.Compose(_preprocess_steps)


# Read documentation here: https://spark.apache.org/docs/3.0.1/sql-pyspark-pandas-with-arrow.html


def preprocess(image_batch):
    """Turn a batch of flattened images into model-ready tensors.

    The input tensor is viewed as ``(N, 256, 256, 3)`` (channels last),
    reordered to channels first ``(N, 3, 256, 256)``, and then passed through
    the module-level ``preprocess_image`` transform pipeline.
    """
    channels_last = image_batch.reshape(-1, 256, 256, 3)
    # (N, H, W, C) -> (N, C, H, W), the layout the torchvision transforms expect.
    channels_first = channels_last.permute(0, 3, 1, 2)
    return preprocess_image(channels_first)


@pandas_udf(ArrayType(FloatType()))
def resnet_predict(pandas_series_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Run ResNet-50 inference on batches of raw images (iterator pandas UDF).

    Args:
        pandas_series_iter: Iterator of pandas Series, one per Arrow batch.
            Each element of a Series is one image given as a flat sequence of
            0-255 pixel values that ``preprocess`` reshapes to 256x256x3.

    Yields:
        One pandas Series per input batch, holding one float vector of raw
        model outputs per image, in input order.
    """
    device = torch.device("cuda")

    # Build the model once per UDF invocation (not once per batch) from the
    # broadcast weights, and move it to the GPU.
    model = resnet50()
    model.load_state_dict(bc_model_state.value)
    model = model.to(device)
    model.eval()

    for pandas_series in pandas_series_iter:
        # Scope inference mode to the computation of a single batch so the
        # thread-local mode is not left switched on across the `yield` while
        # the consumer of this generator is running.
        with torch.inference_mode():
            image_batch = torch.tensor(np.stack(pandas_series.values).astype(np.uint8))
            # Map uint8 values in [0, 255] onto float32 values in [0, 1].
            # Dividing by 255 (not 256) makes the maximum pixel value exactly 1.0.
            image_batch = image_batch / 255.0

            image_batch = preprocess(image_batch).to(device)

            predictions = list(model(image_batch).cpu().numpy())

        yield pd.Series(predictions)


In [0]:
# Define (without executing yet) the inference query over the image column.
image_column = col("image")
predictions = input_data.select(resnet_predict(image_column))

In [0]:
# Cap the number of rows per Arrow record batch; presumably this bounds the
# size of each pandas Series handed to the pandas UDF -- see the Spark
# "pandas with Arrow" documentation to confirm.
ARROW_MAX_RECORDS_PER_BATCH = "64"
spark.conf.set(
    "spark.sql.execution.arrow.maxRecordsPerBatch", ARROW_MAX_RECORDS_PER_BATCH
)

In [0]:

# Benchmark: execute the full prediction query several times and record the
# wall-clock duration of each run.
run_N = 6
run_times = []
for run_index in range(run_N):
    started_at = time.time()
    # Writing to the "noop" format forces the query to execute without
    # persisting the results anywhere.
    predictions.write.format("noop").mode("overwrite").save()
    run_times.append(time.time() - started_at)

# Sanity check that the input dataset has the expected number of rows.
expected_row_count = 16232
assert input_data.count() == expected_row_count

In [0]:
# Report every individual run time and their mean. The mean is taken over the
# measurements actually recorded, so it cannot drift from a separate counter.
print("Run times: ", run_times)
average_run_time = sum(run_times) / len(run_times)
print(f"Average Prediction took: {average_run_time} seconds")

Run times:  [126.96703696250916, 106.93883275985718, 99.0119571685791, 95.7514579296112, 90.47398567199707, 90.08183360099792]
Averge Prediction took: 101.53751734892528 seconds


In [0]:
# Display the active Spark session object as the cell result.
spark