<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark Tensorflow Inference

### Flower Recognition with Keras Resnet50

In this notebook, we demonstrate distributed inference with ResNet50 on the Databricks flower photos dataset.  
From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html

Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  

In [1]:
import os
import shutil
import subprocess
import time
import json
import pandas as pd
from PIL import Image
import numpy as np
import uuid
 
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50

2024-12-10 21:25:21.839435: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-10 21:25:21.846809: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-10 21:25:21.854792: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-10 21:25:21.857325: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-10 21:25:21.863686: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Create the local directory for saved models. `exist_ok=True` keeps the cell
# idempotent on re-run and avoids the exists()-then-mkdir() race.
os.makedirs('models', exist_ok=True)

In [3]:
print(tf.__version__)

# Enable GPU memory growth on every GPU TensorFlow can see.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # With no visible GPUs the loop body never runs, so this is a no-op on CPU-only hosts.
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    # Report the error and carry on rather than aborting the notebook.
    print(err)

2.17.0


## PySpark

In [4]:
from pyspark.sql.functions import col, struct, pandas_udf, PandasUDFType
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf
from typing import Iterator, Tuple

Check the cluster environment to handle any platform-specific Spark configurations.

In [5]:
# Each flag holds the platform's version string when its environment variable is
# set, and False otherwise.
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_VERSION", False)
# Neither managed platform detected -> treat the session as a standalone cluster.
on_standalone = not on_databricks and not on_dataproc

In [6]:
# Build the Spark configuration for the detected platform, then create (or reuse)
# the SparkSession.
conf = SparkConf()

if on_standalone:
    # NOTE(review): CONDA_PREFIX may be unset, in which case conda_env is None and the
    # paths below are rendered as "None/..." — confirm a conda env is always active here.
    conda_env = os.environ.get("CONDA_PREFIX")
    # Point PyTriton to correct libpython3.11.so:
    conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH")
    # Replace the conda environment's libstdc++.so.6 with a symlink to the system copy.
    # This modifies the environment on disk; failures are reported but not fatal.
    source = "/usr/lib/x86_64-linux-gnu/libstdc++.so.6"
    target = f"{conda_env}/lib/libstdc++.so.6"
    try:
        if os.path.islink(target) or os.path.exists(target):
            os.remove(target)
        os.symlink(source, target)
    except OSError as e:
        print(f"Error creating symlink: {e}")
        
    if 'spark' not in globals():
        import socket
        # If Spark was not started with Jupyter, attach to local standalone
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
elif on_dataproc:
    # Point PyTriton to correct libpython3.11.so:
    conda_lib_path="/opt/conda/miniconda3/lib"
    conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_lib_path}:$LD_LIBRARY_PATH") 

# Resource settings applied on every platform. Each task requests 1/8 of a GPU,
# matching the 8 cores per executor, with one GPU per executor.
conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
conf.set("spark.executor.cores", "8")
conf.set("spark.task.resource.gpu.amount", "0.125")
conf.set("spark.executor.resource.gpu.amount", "1")
# Arrow settings for pandas <-> Spark data exchange.
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

# getOrCreate() returns an already-active session if one exists.
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

24/12/10 21:25:23 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
24/12/10 21:25:23 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/10 21:25:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Define the input and output directories.

In [7]:
# Parquet file that holds the flattened input images.
file_name = "datasets/image_data.parquet"

# Prefix for prediction outputs; each inference approach appends its own suffix.
output_file_path = "predictions/predictions"

# Databricks-specific alternative, kept for reference:
# temp_path = "tmp/flowers_{uuid}".format(uuid=str(uuid.uuid1()))
# dbfs_file_path = "/dbfs/{}/".format(temp_path)
# local_file_path = "/{}/image_data.parquet".format(temp_path)
# output_file_path = "/{}/predictions".format(temp_path)

### Prepare trained model and data for inference

Load the ResNet-50 Model and broadcast the weights.

In [8]:
# Instantiate ResNet50 on the driver with its default arguments.
model = ResNet50()
# Broadcast the weight arrays so executor-side UDFs can read them through
# `bc_model_weights.value`.
bc_model_weights = sc.broadcast(model.get_weights())

2024-12-10 21:25:24.616108: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46350 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6


Load the data and save the datasets to one Parquet file.

In [9]:
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
# Fetch the archive through Keras and unpack it (untar=True); wrap the returned
# location in a Path so it can be globbed.
data_dir = pathlib.Path(
    tf.keras.utils.get_file(origin=dataset_url, fname='flower_photos', untar=True)
)
# Count the JPEGs that sit exactly one directory level below the dataset root.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print(f"Image count: {image_count}")

Image count: 3670


In [10]:
import os

# Walk the dataset tree and collect every path whose extension is exactly '.jpg'.
files = []
for dirpath, _dirnames, filenames in os.walk(data_dir):
    for filename in filenames:
        _stem, extension = os.path.splitext(filename)
        if extension == '.jpg':
            files.append(os.path.join(dirpath, filename))
# Keep at most the first 2048 paths.
files = files[:2048]

In [12]:
# Make sure the parent directory of `file_name` exists before writing to it.
os.makedirs(os.path.dirname(file_name) or ".", exist_ok=True)

image_data = []
for file in files:
    # The context manager closes the underlying file handle once the pixels are copied out.
    with Image.open(file) as img:
        # Force three channels so the reshape below cannot fail on grayscale/RGBA inputs,
        # then resize to the 224x224 input size used throughout this notebook.
        resized = img.convert("RGB").resize([224, 224])
        # Flatten to a 1-D float32 vector of length 224*224*3 for storage in one column.
        data = np.asarray(resized, dtype="float32").reshape([224*224*3])

    image_data.append({"data": data})

pandas_df = pd.DataFrame(image_data, columns=['data'])
pandas_df.to_parquet(file_name)
# os.makedirs(dbfs_file_path)
# shutil.copyfile(file_name, dbfs_file_path+file_name)

### Save Model


In [13]:
# Save the driver-side model under models/ using the .keras file extension.
model_path = 'models/resnet50_model.keras'
model.save(model_path)
print(f"Resnet50 Model saved to {model_path}.")

Resnet50 Model saved to models/resnet50_model.keras.


### Load the data into Spark DataFrames

In [14]:
# Read the Parquet file written above into a Spark DataFrame and print its row count.
df = spark.read.parquet(file_name)
print(df.count())

2048


In [None]:
# Decrease the batch size of the Arrow reader to avoid OOM errors on smaller instance types.
# NOTE(review): the session was created with spark.sql.execution.arrow.maxRecordsPerBatch=1000,
# so 1024 slightly raises that setting rather than lowering it — confirm the intended value.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [16]:
# `df.head(1)` returns a list of rows, so its length is 0 for an empty DataFrame.
# `df.head()` returns a single Row or None, so `len()` of it would measure the column
# count, or raise TypeError on an empty frame, instead of triggering the assertion.
assert len(df.head(1)) > 0, "`df` should not be empty" # This line will fail if the vectorized reader runs out of memory

                                                                                

### Model inference via Pandas UDF

Define the function to parse the input data.

In [17]:
def parse_image(image_data):
    """Rescale one flattened image and restore its spatial layout.

    The input is converted to float32, multiplied by 2/255 and shifted by -1
    (so 0 maps to -1 and 255 maps to 1), then reshaped to (224, 224, 3).
    """
    as_float = tf.image.convert_image_dtype(image_data, dtype=tf.float32)
    rescaled = as_float * (2. / 255) - 1
    return tf.reshape(rescaled, [224, 224, 3])

Define the function for model inference.

In [18]:
@pandas_udf(ArrayType(FloatType()))
def pandas_predict_udf(iter: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]:
    """Iterator pandas UDF that runs ResNet50 over batches of flattened images.

    The model is built once per invocation and initialised from the broadcast
    weights. For each incoming batch, one pandas Series of per-row prediction
    vectors is yielded.
    """
    # Enable GPU memory growth to avoid CUDA OOM
    gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        # No-op when no GPU is visible.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(err)

    batch_size = 64
    # Build the architecture with weights=None, then load the broadcast weights.
    model = ResNet50(weights=None)
    model.set_weights(bc_model_weights.value)

    for image_batch in iter:
        # Stack the per-row vectors into one 2-D array before slicing it into a dataset.
        stacked = np.vstack(image_batch)
        pipeline = tf.data.Dataset.from_tensor_slices(stacked)
        pipeline = pipeline.map(parse_image, num_parallel_calls=8)
        pipeline = pipeline.prefetch(5000).batch(batch_size)
        scores = model.predict(pipeline)
        yield pd.Series(list(scores))

Run model inference and save the results to Parquet.

In [19]:
%%time
# Apply the pandas UDF to the "data" column, then collect every prediction row to the driver.
predictions_1 = df.select(pandas_predict_udf(col("data")).alias("prediction"))
results = predictions_1.collect()



CPU times: user 48 ms, sys: 25.7 ms, total: 73.7 ms
Wall time: 18.9 s


                                                                                

In [20]:
# Preview the predictions, truncating each cell to 100 characters.
predictions_1.show(truncate=100)



+----------------------------------------------------------------------------------------------------+
|                                                                                          prediction|
+----------------------------------------------------------------------------------------------------+
|[1.2938889E-4, 2.4666305E-4, 6.765791E-5, 1.2263245E-4, 5.7486624E-5, 3.9616702E-4, 7.0566134E-6,...|
|[4.4501914E-5, 3.5403698E-4, 4.6702033E-5, 8.102543E-5, 3.1704556E-5, 1.9194305E-4, 7.905952E-6, ...|
|[1.05672516E-4, 2.2686279E-4, 3.0055395E-5, 6.523785E-5, 2.352077E-5, 3.7122983E-4, 3.3315896E-6,...|
|[2.0331638E-5, 2.2746396E-4, 7.828012E-5, 6.986782E-5, 4.705316E-5, 9.80732E-5, 5.561918E-6, 2.35...|
|[1.130241E-4, 2.3187004E-4, 5.296914E-5, 1.0871329E-4, 4.027478E-5, 3.7183522E-4, 5.5931855E-6, 3...|
|[9.094467E-5, 2.06384E-4, 4.514821E-5, 7.665891E-5, 3.2262324E-5, 3.3875552E-4, 3.831814E-6, 4.18...|
|[1.07847634E-4, 3.7848807E-4, 7.660533E-5, 1.2446754E-4, 4.7595917E-5, 3

                                                                                

In [21]:
# Write the pandas-UDF predictions to "<output_file_path>_1", replacing any existing output.
predictions_1.write.mode("overwrite").parquet(output_file_path + "_1")

                                                                                

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays 
- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function

In [22]:
def predict_batch_fn():
    """Build a ResNet50 model and return a function that predicts on numpy batches.

    Imports are done inside the function body. The returned callable rescales
    its input with ``x * 2/255 - 1`` before calling ``model.predict``.
    """
    import tensorflow as tf
    from tensorflow.keras.applications.resnet50 import ResNet50

    # Enable GPU memory growth
    visible_gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        # No-op when no GPU is visible.
        for device in visible_gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(err)

    model = ResNet50()

    def predict(inputs):
        # Map values so that 0 -> -1 and 255 -> 1 before inference.
        rescaled = inputs * (2. / 255) - 1
        return model.predict(rescaled)

    return predict

In [23]:
# Wrap predict_batch_fn as a Spark UDF. The single input column is declared with
# tensor shape (224, 224, 3), and rows are sent to the predict function 50 at a time.
classify = predict_batch_udf(predict_batch_fn,
                             input_tensor_shapes=[[224, 224, 3]],
                             return_type=ArrayType(FloatType()),
                             batch_size=50)

In [24]:
# Re-read the dataset from the path it was written to (`file_name`); the bare
# "image_data.parquet" literal does not match "datasets/image_data.parquet".
df = spark.read.parquet(file_name)

In [25]:
%%time
# first pass caches model/fn
# Here the input column is passed wrapped in struct(...).
predictions_2 = df.select(classify(struct("data")).alias("prediction"))
results = predictions_2.collect()



CPU times: user 68 ms, sys: 28.6 ms, total: 96.6 ms
Wall time: 16.2 s


                                                                                

In [26]:
%%time
# Same inference, passing the input column by name.
predictions_2 = df.select(classify("data").alias("prediction"))
results = predictions_2.collect()



CPU times: user 61.9 ms, sys: 17.9 ms, total: 79.8 ms
Wall time: 10.1 s


                                                                                

In [27]:
%%time
# Same inference, passing the input column as a Column object via col().
predictions_2 = df.select(classify(col("data")).alias("prediction"))
results = predictions_2.collect()



CPU times: user 56.5 ms, sys: 17.8 ms, total: 74.3 ms
Wall time: 16.3 s


                                                                                

In [28]:
# Preview the predictions, truncating each cell to 100 characters.
predictions_2.show(truncate=100)



+----------------------------------------------------------------------------------------------------+
|                                                                                          prediction|
+----------------------------------------------------------------------------------------------------+
|[1.296447E-4, 2.465122E-4, 6.7463385E-5, 1.2231144E-4, 5.731739E-5, 3.9644213E-4, 7.0297688E-6, 4...|
|[4.4481887E-5, 3.526653E-4, 4.683818E-5, 8.1168495E-5, 3.178377E-5, 1.9188467E-4, 7.885617E-6, 1....|
|[1.05946536E-4, 2.2744355E-4, 3.0219735E-5, 6.548672E-5, 2.3649674E-5, 3.7177472E-4, 3.353236E-6,...|
|[2.0392703E-5, 2.2817637E-4, 7.840744E-5, 6.9875685E-5, 4.702542E-5, 9.8244425E-5, 5.5829764E-6, ...|
|[1.1312391E-4, 2.31244E-4, 5.279228E-5, 1.0859927E-4, 4.0202678E-5, 3.721753E-4, 5.563934E-6, 3.4...|
|[9.126345E-5, 2.0679034E-4, 4.5165678E-5, 7.679106E-5, 3.234611E-5, 3.3994843E-4, 3.84E-6, 4.1930...|
|[1.07930486E-4, 3.7741542E-4, 7.613175E-5, 1.2414041E-4, 4.7409427E-5, 3

                                                                                

In [29]:
# Write the predict_batch_udf predictions to "<output_file_path>_2", replacing any existing output.
predictions_2.write.mode("overwrite").parquet(output_file_path + "_2")

                                                                                

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-pytriton.png" alt="drawing" width="700"/>

In [30]:
from functools import partial

In [None]:
def triton_server():
    """Load ResNet50 and serve it through a PyTriton server until SIGTERM is received.

    Intended to run as the target of a separate process (see ``start_triton``).
    The model is bound under the name "ResNet50" with one float32 input
    ``images`` of shape (224, 224, 3) and one float32 output ``preds``.
    """
    import signal
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton
    from pyspark import TaskContext

    print(f"SERVER: Initializing ResNet on worker {TaskContext.get().partitionId()}.")

    # Enable GPU memory growth
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)
    
    model = ResNet50()
    # Same x * 2/255 - 1 rescaling that the other inference paths in this notebook apply.
    normalization_layer = tf.keras.layers.Rescaling(scale=2./255, offset=-1)

    @batch
    def _infer_fn(**inputs):
        # Inference callback: rescale the "images" input and return predictions as "preds".
        images = inputs["images"]
        normalized_images = normalization_layer(images)
        return {
            "preds": model.predict(normalized_images),
        }

    with Triton() as triton:

        triton.bind(
            model_name="ResNet50",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="images", dtype=np.float32, shape=(224, 224, 3)),
            ],
            outputs=[
                Tensor(name="preds", dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=100,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            # SIGTERM handler: ask the server to stop.
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        # Register the handler before serving so a SIGTERM sent later is handled by stop_triton.
        signal.signal(signal.SIGTERM, stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

def start_triton(url, model_name):
    """Start a Triton server process on this host and wait until it is reachable.

    Args:
        url: Address the readiness-probe client connects to.
        model_name: Name of the model the probe client targets.

    Returns:
        ``[(hostname, pid)]`` for the started server process, or ``[]`` if some
        process is already using port 8001 on this host.

    Raises:
        RuntimeError: If the server process exits before it becomes ready.
    """
    import socket
    import psutil
    from multiprocessing import Process
    from pytriton.client import ModelClient

    for conn in psutil.net_connections(kind="inet"):
        # `laddr` can be an empty tuple for sockets without a bound port; skip those
        # instead of failing on the attribute access.
        if conn.laddr and conn.laddr.port == 8001:
            print(f"Process {conn.pid} is already running on port 8001. Please stop it before starting a new one.")
            return []

    hostname = socket.gethostname()
    process = Process(target=triton_server)
    process.start()

    # The context manager closes the probe client once readiness is established.
    with ModelClient(url, model_name) as client:
        ready = False
        while not ready:
            try:
                client.wait_for_server(5)
                ready = True
            except Exception as e:
                # If the server process has died it will never become ready:
                # fail loudly instead of polling forever.
                if not process.is_alive():
                    raise RuntimeError(
                        f"Triton server process on {hostname} exited with code "
                        f"{process.exitcode} before becoming ready"
                    ) from e
                print(f"Waiting for server to be ready: {e}")

    return [(hostname, process.pid)]

#### Start Triton servers

To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU.  

In [32]:
def _use_stage_level_scheduling(spark, rdd):

    if spark.version < "3.4.0":
        raise Exception("Stage-level scheduling is not supported in Spark < 3.4.0")

    executor_cores = spark.conf.get("spark.executor.cores")
    assert executor_cores is not None, "spark.executor.cores is not set"
    executor_gpus = spark.conf.get("spark.executor.resource.gpu.amount")
    assert executor_gpus is not None and int(executor_gpus) <= 1, "spark.executor.resource.gpu.amount must be set and <= 1"

    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    spark_plugins = spark.conf.get("spark.plugins", " ")
    assert spark_plugins is not None
    spark_rapids_sql_enabled = spark.conf.get("spark.rapids.sql.enabled", "true")
    assert spark_rapids_sql_enabled is not None

    task_cores = (
        int(executor_cores)
        if "com.nvidia.spark.SQLPlugin" in spark_plugins
        and "true" == spark_rapids_sql_enabled.lower()
        else (int(executor_cores) // 2) + 1
    )

    task_gpus = 1.0
    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    rp = ResourceProfileBuilder().require(treqs).build
    print(f"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})")

    return rdd.withResources(rp)

**Specify the number of nodes in the cluster.**  
Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 2 nodes by default. 

In [33]:
num_nodes = 1  # Change based on cluster setup

In [34]:
# Address and model name used by both the readiness probe and the inference clients.
url = "localhost"
model_name = "ResNet50"

sc = spark.sparkContext
# One element and one partition per node, so each partition becomes one server-launch task.
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)

Reqesting stage-level resources: (cores=5, gpu=1.0)


In [35]:
# Run start_triton once per partition in a barrier stage and collect {hostname: server PID}.
# NOTE(review): start_triton returns [] when port 8001 is already in use, so such a host
# will be missing from `pids`.
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(url, model_name)).collectAsMap()
print("Triton Server PIDs:\n", json.dumps(pids, indent=4))

[Stage 17:>                                                         (0 + 1) / 1]

Triton Server PIDs:
 {
    "cb4ae00-lcedt": 1206043
}


                                                                                

#### Define client function

In [36]:
def triton_fn(url, model_name, init_timeout_s):
    """Return a function that sends one input batch to the Triton server.

    Args:
        url: Address of the Triton server.
        model_name: Name of the served model.
        init_timeout_s: Passed to ``ModelClient`` as ``init_timeout_s``.

    Returns:
        A callable that takes a batch of inputs and returns the "preds" entry
        of the server's response.
    """
    import numpy as np
    from pytriton.client import ModelClient

    print(f"CLIENT: Connecting to {model_name} at {url}")

    def infer_batch(inputs):
        # A client is opened and closed around every batch.
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            response = client.infer_batch(inputs)
        return response["preds"]

    return infer_batch

#### Load DataFrame

In [37]:
# Re-read the dataset from the path it was written to (`file_name`); the bare
# "image_data.parquet" literal does not match "datasets/image_data.parquet".
df = spark.read.parquet(file_name)

#### Run inference

In [38]:
from functools import partial

# Bind the connection arguments up front so predict_batch_udf receives a zero-argument
# factory; rows are sent to the returned infer_batch function 50 at a time.
classify = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=500),
                             input_tensor_shapes=[[224, 224, 3]],
                             return_type=ArrayType(FloatType()),
                             batch_size=50)

In [39]:
%%time
# first pass caches model/fn
# Here the input column is passed wrapped in struct(...).
predictions_3 = df.select(classify(struct("data")).alias("prediction"))
results = predictions_3.collect()



CPU times: user 116 ms, sys: 45.5 ms, total: 162 ms
Wall time: 18.8 s


                                                                                

In [40]:
%%time
# Same inference, passing the input column by name.
predictions_3 = df.select(classify("data").alias("prediction"))
results = predictions_3.collect()



CPU times: user 63.6 ms, sys: 24.7 ms, total: 88.3 ms
Wall time: 12.4 s


                                                                                

In [41]:
%%time
# Same inference, passing the input column as a Column object via col().
predictions_3 = df.select(classify(col("data")).alias("prediction"))
results = predictions_3.collect()



CPU times: user 49.2 ms, sys: 18.4 ms, total: 67.7 ms
Wall time: 12.7 s


                                                                                

In [42]:
# Preview the predictions, truncating each cell to 100 characters.
predictions_3.show(truncate=100)



+----------------------------------------------------------------------------------------------------+
|                                                                                          prediction|
+----------------------------------------------------------------------------------------------------+
|[1.296447E-4, 2.465122E-4, 6.7463385E-5, 1.2231144E-4, 5.731739E-5, 3.9644213E-4, 7.0297688E-6, 4...|
|[4.4481887E-5, 3.526653E-4, 4.683818E-5, 8.1168495E-5, 3.178377E-5, 1.9188467E-4, 7.885617E-6, 1....|
|[1.05946536E-4, 2.2744355E-4, 3.0219735E-5, 6.548672E-5, 2.3649674E-5, 3.7177472E-4, 3.353236E-6,...|
|[2.0392703E-5, 2.2817637E-4, 7.840744E-5, 6.9875685E-5, 4.702542E-5, 9.8244425E-5, 5.5829764E-6, ...|
|[1.1312391E-4, 2.31244E-4, 5.279228E-5, 1.0859927E-4, 4.0202678E-5, 3.721753E-4, 5.563934E-6, 3.4...|
|[9.126345E-5, 2.0679034E-4, 4.5165678E-5, 7.679106E-5, 3.234611E-5, 3.3994843E-4, 3.84E-6, 4.1930...|
|[1.07930486E-4, 3.7741542E-4, 7.613175E-5, 1.2414041E-4, 4.7409427E-5, 3

                                                                                

In [43]:
# Write the Triton predictions to "<output_file_path>_3", replacing any existing output.
predictions_3.write.mode("overwrite").parquet(output_file_path + "_3")

                                                                                

#### Stop Triton Server on each executor

In [44]:
def stop_triton(pids):
    """Send SIGTERM to this host's Triton server process and wait for it to exit.

    Args:
        pids: Mapping of hostname to server PID.

    Returns:
        ``[True]`` once the PID can no longer be signalled, or ``[False]`` if it
        can still be signalled after all five checks.

    Raises:
        AssertionError: If ``pids`` has no entry for this host.
    """
    import os
    import socket
    import signal
    import time

    hostname = socket.gethostname()
    pid = pids.get(hostname, None)
    assert pid is not None, f"Could not find pid for {hostname}"

    os.kill(pid, signal.SIGTERM)
    # Initial pause before the first check.
    time.sleep(7)

    checks_left = 5
    while checks_left > 0:
        try:
            # Signal 0 delivers nothing; it raises OSError when the PID cannot be signalled.
            os.kill(pid, 0)
        except OSError:
            return [True]
        time.sleep(5)
        checks_left -= 1

    return [False]

# Mirror the launch step: one partition per node, the same stage-level resource
# profile, and a barrier stage that calls stop_triton once per partition.
shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)
shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()

Reqesting stage-level resources: (cores=5, gpu=1.0)


                                                                                

[True]

In [45]:
# Shut down the SparkSession now that all inference runs are complete.
spark.stop()