In [0]:
import time
import torch
torch.__version__
torch.cuda.is_available()

Out[1]: True

In [0]:
print("Nums gpus per task: ", spark.conf.get("spark.task.resource.gpu.amount"))
print("Executor memory: ", spark.conf.get("spark.executor.memory"))

Nums gpus per task:  0.0833
Executor memory:  148728m


In [0]:
import pandas as pd
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights



In [0]:
# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

In [0]:
# Create and broadcast model state. Equivalent to AIR Checkpoint
model_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()
# sc is already initialized by Databricks. Broadcast the model state to all executors.
bc_model_state = sc.broadcast(model_state)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [0]:
input_data = spark.read.format("parquet").load("s3://air-example-data-2/10G-image-data-synthetic-raw-parquet")

In [0]:
# Preprocessing
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, FloatType

import numpy as np

import torch
import time

# Read documentation here: https://spark.apache.org/docs/3.0.1/sql-pyspark-pandas-with-arrow.html

@pandas_udf(ArrayType(FloatType()))
def preprocess(image: pd.Series) -> pd.Series:
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    print(f"number of images: {len(image)}")
    # Spark has no tensor support, so it flattens the image tensor to a single array during read.
    # Each image is represented as a flattened numpy array.
    # We have to reshape back to the original number of dimensions.
    # Need to convert to float dtype otherwise torchvision transforms will complain. The data is read as short (int16) by default
    batch_dim = len(image)
    numpy_batch = np.stack(image.values)
    reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(float)
    
    torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))
    preprocessed_images = preprocess(torch_tensor).numpy()
    # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it
    preprocessed_images = [image.flatten() for image in preprocessed_images]
    return pd.Series(preprocessed_images)

In [0]:
preprocessed_data = input_data.select(preprocess(col("image")))

In [0]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

In [0]:
def predict_custom_batching(input_rdd_iter):
    with torch.inference_mode():
        model = resnet50()
        model.load_state_dict(bc_model_state.value)
        model = model.to(torch.device("cuda")) # Move model to GPU
        model.eval()
    
        input_batch = []
        output = []
        for image in input_rdd_iter:
            input_batch.append(image)

            if len(input_batch) == 1000:
                numpy_batch = np.array(input_batch)
                reshaped_images = numpy_batch.reshape(1000, 3, 224, 224)
                gpu_batch = torch.Tensor(reshaped_images).to(torch.device("cuda"))
                predictions = list(model(gpu_batch).cpu().numpy())
                assert len(predictions) == 1000
                output.extend(list(predictions))

        return output

In [0]:
from pyspark.resource.profile import ResourceProfileBuilder
from pyspark.resource.requests import TaskResourceRequests

task_res_req = TaskResourceRequests().cpus(int(spark.sparkContext.getConf().get("spark.task.cpus", "1")))
task_res_req.resource("gpu", 1)
res_profile = ResourceProfileBuilder().require(task_res_req).build

preprocessed_rdd = preprocessed_data.rdd.withResources(res_profile)
predictions = preprocessed_rdd.mapPartitions(predict_custom_batching).collect()

[0;31m---------------------------------------------------------------------------[0m
[0;31mPy4JJavaError[0m                             Traceback (most recent call last)
File [0;32m<command-848636385098289>:18[0m
[1;32m     15[0m res_profile [38;5;241m=[39m ResourceProfileBuilder()[38;5;241m.[39mrequire(task_res_req)[38;5;241m.[39mbuild
[1;32m     17[0m preprocessed_rdd [38;5;241m=[39m preprocessed_data[38;5;241m.[39mrdd[38;5;241m.[39mwithResources(res_profile)
[0;32m---> 18[0m predictions [38;5;241m=[39m preprocessed_rdd[38;5;241m.[39mmapPartitions(predict_custom_batching)[38;5;241m.[39mcollect()

File [0;32m/databricks/spark/python/pyspark/instrumentation_utils.py:48[0m, in [0;36m_wrap_function.<locals>.wrapper[0;34m(*args, **kwargs)[0m
[1;32m     46[0m start [38;5;241m=[39m time[38;5;241m.[39mperf_counter()
[1;32m     47[0m [38;5;28;01mtry[39;00m:
[0;32m---> 48[0m     res [38;5;241m=[39m [43mfunc[49m[43m([49m[38;5;241;43m*[39;49

In [0]:
start_time = time.time()
predictions.write.mode("overwrite").format("noop").save()
end_time = time.time()
print(f"Prediction took: {end_time-start_time} seconds")

assert predictions.count() == 16232

[0;31m---------------------------------------------------------------------------[0m
[0;31mPythonException[0m                           Traceback (most recent call last)
File [0;32m<command-848636385098290>:3[0m
[1;32m      1[0m start_time [38;5;241m=[39m time[38;5;241m.[39mtime()
[1;32m      2[0m predictions[38;5;241m.[39mpersist() [38;5;66;03m# Persist is a lazy operation- need to also have the line below[39;00m
[0;32m----> 3[0m predictions[38;5;241m.[39mwrite[38;5;241m.[39mmode([38;5;124m"[39m[38;5;124moverwrite[39m[38;5;124m"[39m)[38;5;241m.[39mformat([38;5;124m"[39m[38;5;124mnoop[39m[38;5;124m"[39m)[38;5;241m.[39msave()
[1;32m      4[0m end_time [38;5;241m=[39m time[38;5;241m.[39mtime()
[1;32m      5[0m [38;5;28mprint[39m([38;5;124mf[39m[38;5;124m"[39m[38;5;124mPrediction took: [39m[38;5;132;01m{[39;00mend_time[38;5;241m-[39mstart_time[38;5;132;01m}[39;00m[38;5;124m seconds[39m[38;5;124m"[39m)

File [0;32m/databric