# Distributed model inference using TensorFlow Keras
From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html

In [1]:
import os
import shutil
import subprocess
import time
import uuid
from typing import Iterator

import numpy as np
import pandas as pd
from PIL import Image

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50

from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql import SparkSession

2024-09-24 01:23:38.569048: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 01:23:38.585924: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-24 01:23:38.591041: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-24 01:23:38.603766: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
num_threads = 6

# Creating a local Spark session for demonstration, in case it hasn't already been created.
# `getOrCreate()` at the end reuses an already-running session in this kernel.

_config = {
    "spark.master": f"local[{num_threads}]",
    "spark.driver.host": "127.0.0.1",
    "spark.task.maxFailures": "1",
    "spark.driver.memory": "8g",
    "spark.executor.memory": "8g",
    "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
    "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.python.worker.reuse": "true",
}

# Apply every setting to the builder, then materialize the session.
builder = SparkSession.builder.appName("spark-dl-example")
for conf_key, conf_value in _config.items():
    builder = builder.config(conf_key, conf_value)
spark = builder.getOrCreate()

sc = spark.sparkContext

24/09/24 01:23:41 WARN Utils: Your hostname, dgx2h0194.spark.sjc4.nvmetal.net resolves to a loopback address: 127.0.1.1; using 10.150.30.2 instead (on interface enp134s0f0np0)
24/09/24 01:23:41 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/24 01:23:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/09/24 01:23:42 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
24/09/24 01:23:42 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
24/09/24 01:23:42 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.
24/09/24 01:23:42 WARN Utils: Service 'SparkUI' could not bind on port 4043. Attempting port 4044.
24/09/24 01:23:42 WARN Utils: Service 'SparkUI' could not bind on p

In [3]:
# Enable GPU memory growth: allocate GPU memory on demand rather than all at once,
# since several processes (driver + Spark Python workers) may share the same GPUs.
visible_gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # An empty list simply means there is nothing to configure.
    for device in visible_gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    # Raised when the GPUs were already initialized; report it and keep going.
    print(err)

In [4]:
# Local Parquet file holding the flattened images, and the output path prefix for predictions.
file_name, output_file_path = "image_data.parquet", "predictions"

### Prepare trained model and data for inference

Load the ResNet-50 model and broadcast its weights to the executors.

In [5]:
# Build ResNet-50 on the driver with its default (pretrained) weights.
model = ResNet50()
# Broadcast the weight arrays once so each executor can rebuild the network
# (`ResNet50(weights=None)` + `set_weights`) without fetching the weights itself.
bc_model_weights = sc.broadcast(model.get_weights())

2024-09-24 01:23:46.468094: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 29659 MB memory:  -> device: 0, name: Tesla V100-SXM3-32GB-H, pci bus id: 0000:34:00.0, compute capability: 7.0
2024-09-24 01:23:46.469480: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 30823 MB memory:  -> device: 1, name: Tesla V100-SXM3-32GB-H, pci bus id: 0000:36:00.0, compute capability: 7.0
2024-09-24 01:23:46.470743: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 30823 MB memory:  -> device: 2, name: Tesla V100-SXM3-32GB-H, pci bus id: 0000:39:00.0, compute capability: 7.0
2024-09-24 01:23:46.471975: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 30823 MB memory:  -> device: 3, name: Tesla V100-SXM3-32GB-H, pc

Download the flowers dataset, then save the resized images to a single Parquet file.

In [6]:
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
# Download (or reuse the locally cached copy of) the archive and extract it; wrap the
# returned location in a Path so it can be globbed/walked below.
data_dir = pathlib.Path(
    tf.keras.utils.get_file(origin=dataset_url, fname='flower_photos', untar=True)
)

In [7]:
# Count JPEGs exactly one directory below the dataset root
# (presumably one subdirectory per flower class).
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print(image_count)

3670


In [8]:
import os

# Collect every .jpg under the dataset root. os.walk's traversal order depends on the
# filesystem, so sort the paths to make the 2048-image subset below reproducible.
files = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(data_dir)
    for filename in filenames
    if os.path.splitext(filename)[1] == '.jpg'
)
files = files[:2048]
len(files)

2048

In [9]:
# Show where `tf.keras.utils.get_file` placed the extracted dataset.
print(data_dir)

/rishic/.keras/datasets/flower_photos


In [10]:
image_data = []
for file in files:
    # `with` releases the file handle promptly. convert("RGB") guarantees 3 channels, so the
    # flatten below cannot fail on grayscale/RGBA/CMYK JPEGs (a no-op copy for RGB input).
    with Image.open(file) as img:
        rgb = img.convert("RGB").resize((224, 224))
    # Flatten to a 1-D float32 vector of 224*224*3 values, one Parquet row per image.
    data = np.asarray(rgb, dtype="float32").reshape([224 * 224 * 3])

    image_data.append({"data": data})

pandas_df = pd.DataFrame(image_data, columns=['data'])
pandas_df.to_parquet(file_name)

### Save Model


In [11]:
# Remove any previous export so `model.export` writes into a clean directory.
# shutil.rmtree is portable and spawns no external process; ignore_errors makes it a
# no-op when the directory does not exist yet (matching `rm -rf`).
shutil.rmtree("resnet50_model", ignore_errors=True)

0

In [30]:
# Export a SavedModel directory exposing a `serve` endpoint
# (float32 input of shape [None, 224, 224, 3]); Triton later loads it as `model.savedmodel`.
model.export("resnet50_model")

INFO:tensorflow:Assets written to: resnet50_model/assets


INFO:tensorflow:Assets written to: resnet50_model/assets


Saved artifact at 'resnet50_model'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 1000), dtype=tf.float32, name=None)
Captures:
  139905008456848: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008458960: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008458768: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008458576: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008457424: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008458000: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008461840: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008462992: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008463952: TensorSpec(shape=(), dtype=tf.resource, name=None)
  139905008463376: TensorSpec(shape=(), dtype=tf.resource, name=None)
  1399050084641

### Load the data into Spark DataFrames

In [14]:
from pyspark.sql.types import *
# Read the image Parquet file back as a Spark DataFrame and report its row count.
df = spark.read.parquet(file_name)
num_rows = df.count()
print(num_rows)

2048


In [15]:
# Optional knob (disabled): cap rows per Arrow batch handed to Python UDFs.
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
# Each row is a flattened 224*224*3 float image, so limit rows per vectorized
# Parquet reader batch to keep the reader's memory bounded.
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [16]:
# Materialize one row: this forces the vectorized Parquet reader to decode a batch, so an
# out-of-memory reader fails here rather than mid-inference.
# df.head() returns None for an empty DataFrame; test for that explicitly, because
# len(None) would raise TypeError instead of the intended AssertionError.
first_row = df.head()
assert first_row is not None and len(first_row) > 0, "`df` should not be empty"

                                                                                

### Model inference via pandas UDF

In [17]:
def parse_image(image_data):
    """Scale one flattened image to [-1, 1] and restore its 224x224x3 shape.

    Args:
        image_data: flat tensor of 224*224*3 pixel values (presumably float32 in [0, 255],
            as written by the data-prep cell).

    Returns:
        float32 tensor of shape [224, 224, 3].
    """
    # NOTE(review): Keras' ResNet50 weights are documented to expect
    # `resnet50.preprocess_input` (caffe-style), not [-1, 1] scaling — confirm intent.
    pixels = tf.image.convert_image_dtype(image_data, dtype=tf.float32)
    scaled = pixels * (2. / 255) - 1
    return tf.reshape(scaled, [224, 224, 3])

In [18]:
@pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(image_batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Run ResNet-50 over each Arrow batch of flattened images.

    Iterator-of-Series pandas UDF: the model is built once per invocation and then
    reused for every batch the iterator yields.

    Args:
        image_batch_iter: iterator of pandas Series; each element is a flat array of
            224*224*3 pixel values (as written by the data-prep cell).

    Yields:
        pandas Series with one prediction array per input row.
    """
    # Enable GPU memory growth to avoid CUDA OOM
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)

    batch_size = 64
    # Rebuild the architecture without downloading weights, then load the broadcast ones.
    model = ResNet50(weights=None)
    model.set_weights(bc_model_weights.value)
    for image_batch in image_batch_iter:
        if len(image_batch) == 0:
            # np.vstack raises on an empty sequence; emit an equally empty result.
            yield pd.Series([], dtype=object)
            continue
        images = np.vstack(image_batch)
        dataset = (
            tf.data.Dataset.from_tensor_slices(images)
            .map(parse_image, num_parallel_calls=8)
            # Batch first, then prefetch, so whole batches are prepared ahead of the
            # model instead of individual images.
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE)
        )
        preds = model.predict(dataset)
        yield pd.Series(list(preds))



In [19]:
%%time
# Apply the pandas UDF to the image column and preview the first predictions.
prediction_col = predict_batch_udf(col("data")).alias("prediction")
predictions_df = df.select(prediction_col)
predictions_df.show(truncate=120)

2024-09-24 01:25:01.047489: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 01:25:01.063458: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-24 01:25:01.068152: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-24 01:25:01.080010: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-24 01:25:07.837282: I tensorflow/core

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[4.7788864E-5, 3.036927E-4, 7.026324E-5, 1.772748E-4, 4.5683286E-5, 1.9182988E-4, 9.988368E-6, 7.264478E-5, 1.4183694...|
|[1.2047288E-5, 2.4665435E-4, 2.3863144E-4, 1.9301432E-4, 1.5375564E-4, 4.5055505E-5, 2.220773E-5, 3.791191E-4, 1.5702...|
|[1.3302326E-4, 2.696228E-4, 5.517897E-5, 9.9901976E-5, 4.7618698E-5, 4.4045786E-4, 6.8055174E-6, 3.486012E-5, 1.26733...|
|[1.5190376E-5, 2.9297185E-4, 1.17424854E-4, 8.6468535E-5, 7.027255E-5, 7.291867E-5, 1.1590379E-5, 2.905424E-4, 1.7357...|
|[1.1257283E-4, 2.5857892E-4, 5.5297343E-5, 1.0446069E-4, 4.6776848E-5, 3.4146357E-4, 6.6849643E-6, 3.6820922E-5, 1.15...|
|[9.95129E-5, 3.

In [20]:
%%time
# Run inference over the whole dataset, replacing any previous output.
predictions_df.write.parquet(output_file_path, mode="overwrite")

2024-09-24 01:25:52.662786: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 01:25:52.662786: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 01:25:52.662789: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 01:25:52.680249: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-24 01:25:52.680249: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory fo

CPU times: user 136 ms, sys: 104 ms, total: 240 ms
Wall time: 32.3 s


                                                                                

### Model inference using Spark DL API

In [21]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col
from pyspark.sql.types import ArrayType, FloatType

In [22]:
def predict_batch_fn():
    """Model factory for `pyspark.ml.functions.predict_batch_udf`.

    Configures GPU memory growth, builds ResNet-50 once, and returns a closure that
    scales a batch of images to [-1, 1] and runs the model on it.

    Returns:
        predict(inputs) -> model outputs for the batch.
    """
    import tensorflow as tf
    from tensorflow.keras.applications.resnet50 import ResNet50

    # Enable GPU memory growth; RuntimeError means the GPUs were already initialized.
    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    try:
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(err)

    resnet = ResNet50()

    def predict(inputs):
        # Same [-1, 1] scaling as `parse_image`, applied to the whole batch at once.
        return resnet.predict(inputs * (2. / 255) - 1)

    return predict

In [23]:
# Wrap the factory as a Spark UDF: rows are reshaped to 224x224x3 and sent in batches of 50.
classify = predict_batch_udf(
    predict_batch_fn,
    return_type=ArrayType(FloatType()),
    input_tensor_shapes=[[224, 224, 3]],
    batch_size=50,
)

In [24]:
# Optional knob (disabled): cap rows per Arrow batch handed to Python UDFs.
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
# Each row is a flattened 224*224*3 float image, so limit rows per vectorized
# Parquet reader batch to keep the reader's memory bounded.
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [25]:
# Reuse the `file_name` constant so the dataset path is defined in one place.
df = spark.read.parquet(file_name)

In [26]:
%%time
# first pass caches model/fn
# Input passed as a struct column; the next cells call the same UDF with a column
# name (string) and with a Column object.
predictions = df.select(classify(struct("data")).alias("prediction"))
predictions.show(truncate=120)

I0000 00:00:1727141215.666478 3948946 service.cc:146] XLA service 0x7fbfe80d5110 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727141215.666502 3948946 service.cc:154]   StreamExecutor device (0): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141215.666506 3948946 service.cc:154]   StreamExecutor device (1): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141215.666511 3948946 service.cc:154]   StreamExecutor device (2): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141215.666515 3948946 service.cc:154]   StreamExecutor device (3): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141215.666520 3948946 service.cc:154]   StreamExecutor device (4): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141215.666523 3948946 service.cc:154]   StreamExecutor device (5): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141215.666528 3948946 service.cc:15

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[4.7788864E-5, 3.036927E-4, 7.026324E-5, 1.772748E-4, 4.5683286E-5, 1.9182988E-4, 9.988368E-6, 7.264478E-5, 1.4183694...|
|[1.2047288E-5, 2.4665435E-4, 2.3863144E-4, 1.9301432E-4, 1.5375564E-4, 4.5055505E-5, 2.220773E-5, 3.791191E-4, 1.5702...|
|[1.3302326E-4, 2.696228E-4, 5.517897E-5, 9.9901976E-5, 4.7618698E-5, 4.4045786E-4, 6.8055174E-6, 3.486012E-5, 1.26733...|
|[1.5190376E-5, 2.9297185E-4, 1.17424854E-4, 8.6468535E-5, 7.027255E-5, 7.291867E-5, 1.1590379E-5, 2.905424E-4, 1.7357...|
|[1.1257283E-4, 2.5857892E-4, 5.5297343E-5, 1.0446069E-4, 4.6776848E-5, 3.4146357E-4, 6.6849643E-6, 3.6820922E-5, 1.15...|
|[9.95129E-5, 3.

I0000 00:00:1727141217.755256 3948946 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step
                                                                                

In [27]:
%%time
# Same inference, passing the input column by name instead of as a struct.
predictions = df.select(classify("data").alias("prediction"))
predictions.show(truncate=120)

I0000 00:00:1727141235.787909 3948707 service.cc:146] XLA service 0x7fbfe00d63e0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727141235.787940 3948707 service.cc:154]   StreamExecutor device (0): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141235.787945 3948707 service.cc:154]   StreamExecutor device (1): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141235.787950 3948707 service.cc:154]   StreamExecutor device (2): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141235.787953 3948707 service.cc:154]   StreamExecutor device (3): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141235.787958 3948707 service.cc:154]   StreamExecutor device (4): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141235.787962 3948707 service.cc:154]   StreamExecutor device (5): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
I0000 00:00:1727141235.787967 3948707 service.cc:15

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[4.7788864E-5, 3.036927E-4, 7.026324E-5, 1.772748E-4, 4.5683286E-5, 1.9182988E-4, 9.988368E-6, 7.264478E-5, 1.4183694...|
|[1.2047288E-5, 2.4665435E-4, 2.3863144E-4, 1.9301432E-4, 1.5375564E-4, 4.5055505E-5, 2.220773E-5, 3.791191E-4, 1.5702...|
|[1.3302326E-4, 2.696228E-4, 5.517897E-5, 9.9901976E-5, 4.7618698E-5, 4.4045786E-4, 6.8055174E-6, 3.486012E-5, 1.26733...|
|[1.5190376E-5, 2.9297185E-4, 1.17424854E-4, 8.6468535E-5, 7.027255E-5, 7.291867E-5, 1.1590379E-5, 2.905424E-4, 1.7357...|
|[1.1257283E-4, 2.5857892E-4, 5.5297343E-5, 1.0446069E-4, 4.6776848E-5, 3.4146357E-4, 6.6849643E-6, 3.6820922E-5, 1.15...|
|[9.95129E-5, 3.

I0000 00:00:1727141237.865103 3948707 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step
                                                                                

In [28]:
%%time
# Input passed as a Column; writing (unlike show) runs inference over every row.
predictions = df.select(classify(col("data")).alias("prediction"))
predictions.write.mode("overwrite").parquet(output_file_path + "_1")

2024-09-24 01:27:34.143820: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-24 01:27:34.161243: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-24 01:27:34.166472: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-24 01:27:34.178907: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-24 01:27:41.293232: I tensorflow/core

CPU times: user 178 ms, sys: 173 ms, total: 351 ms
Wall time: 43.3 s


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step 
                                                                                

### Using Triton Inference Server

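Note: to simulate running in a different node or environment, you can restart the kernel and run from this point. First, re-run the import cell and the Spark session cell, which define `spark` and `sc`, and the cell that defines `output_file_path`, because the cells below use these names.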

In [29]:
import os
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [31]:
%%bash
# copy model to expected layout for Triton
# Fail fast: without `set -e`, a failed mkdir/cp is masked by a later successful
# command and Triton would start against an incomplete model repository.
set -e
rm -rf models
mkdir -p models/resnet50/1
cp -r resnet50_model models/resnet50/1/model.savedmodel

# add config.pbtxt
cp models_config/resnet50/config.pbtxt models/resnet50/config.pbtxt

#### Start Triton Server on each executor

In [32]:
num_executors = 1
triton_models_dir = "{}/models".format(os.getcwd())
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it, startup_timeout_s=300, poll_interval_s=5):
    """Start a Triton container on this executor (unless one is already running) and
    wait until the server reports ready.

    Intended for a barrier `mapPartitions` with one partition per executor.

    Args:
        it: partition iterator supplied by `mapPartitions` (unused).
        startup_timeout_s: seconds to wait for readiness before giving up.
        poll_interval_s: seconds to sleep between readiness probes.

    Returns:
        `[True]`, one marker per partition for the driver's `collect()`.

    Raises:
        TimeoutError: if the server is not ready within `startup_timeout_s`.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
        return [True]

    container = docker_client.containers.run(
        "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
        detach=True,
        device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
        name="spark-triton",
        network_mode="host",
        remove=True,
        shm_size="512M",
        volumes={triton_models_dir: {"bind": "/models", "mode": "ro"}}
    )
    print(">>>> starting triton: {}".format(container.short_id))

    # Poll for readiness, sleeping between *every* probe — whether it raised (endpoint
    # not listening yet) or returned False — so the loop never busy-spins, and give up
    # after a deadline instead of hanging forever if the server never comes up.
    triton_client = grpcclient.InferenceServerClient("localhost:8001")
    deadline = time.monotonic() + startup_timeout_s
    while True:
        try:
            if triton_client.is_server_ready():
                break
        except Exception:
            pass  # deliberate: connection errors are expected while the server starts
        if time.monotonic() >= deadline:
            raise TimeoutError(
                "Triton not ready after {}s (container {})".format(startup_timeout_s, container.short_id)
            )
        time.sleep(poll_interval_s)

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

>>>> starting triton: a13519cbf860                                  (0 + 1) / 1]
                                                                                

[True]

#### Run inference

In [39]:
def triton_fn(triton_uri, model_name):
    """Factory for `predict_batch_udf` that forwards batches to a Triton server.

    Connects once and fetches the model metadata, then returns a `predict` closure
    that builds gRPC requests from numpy inputs.

    Args:
        triton_uri: gRPC endpoint, e.g. "localhost:8001".
        model_name: model name in the Triton model repository.

    Returns:
        predict(inputs): a single numpy array, or a dict of output name -> numpy array
        when the model has more than one output.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton wire datatypes -> numpy dtypes used to cast request tensors.
    # (Previously "FP64" was listed twice and the unsigned types were missing.)
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "UINT8": np.dtype(np.uint8),
        "UINT16": np.dtype(np.uint16),
        "UINT32": np.dtype(np.uint32),
        "UINT64": np.dtype(np.uint64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object),
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            inputs = inputs * (2. / 255) - 1  # add normalization
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            # NOTE(review): no normalization is applied on this path — confirm that
            # multi-input models expect raw values.
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [40]:
from functools import partial

# Bind the endpoint and model name so predict_batch_udf can call the factory with no args.
triton_predict_fn = partial(triton_fn, triton_uri="localhost:8001", model_name="resnet50")
classify = predict_batch_udf(
    triton_predict_fn,
    return_type=ArrayType(FloatType()),
    input_tensor_shapes=[[224, 224, 3]],
    batch_size=50,
)

In [41]:
# Optional knob (disabled): cap rows per Arrow batch handed to Python UDFs.
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
# Each row is a flattened 224*224*3 float image, so limit rows per vectorized
# Parquet reader batch to keep the reader's memory bounded.
spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")

In [42]:
# Path is spelled out (not `file_name`) so this section also works after a kernel restart.
df = spark.read.format("parquet").load("image_data.parquet")

In [43]:
%%time
# first pass caches model/fn
# Input passed as a struct column; the next cells call the same UDF with a column
# name (string) and with a Column object.
predictions = df.select(classify(struct("data")).alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[4.7788824E-5, 3.0369216E-4, 7.026308E-5, 1.772744E-4, 4.5683184E-5, 1.9182982E-4, 9.988331E-6, 7.264462E-5, 1.418365...|
|[1.2047234E-5, 2.466537E-4, 2.3863082E-4, 1.9301384E-4, 1.5375504E-4, 4.5055393E-5, 2.2207663E-5, 3.7911904E-4, 1.570...|
|[1.3302344E-4, 2.6962307E-4, 5.5178945E-5, 9.990197E-5, 4.761874E-5, 4.4045786E-4, 6.8055238E-6, 3.486012E-5, 1.26733...|
|[1.5190398E-5, 2.92972E-4, 1.17424854E-4, 8.6468666E-5, 7.0272654E-5, 7.291892E-5, 1.1590384E-5, 2.9054214E-4, 1.7357...|
|[1.12572874E-4, 2.5857877E-4, 5.529736E-5, 1.04460785E-4, 4.6776888E-5, 3.4146404E-4, 6.684954E-6, 3.6820937E-5, 1.15...|
|[9.9512836E-5, 

                                                                                

In [44]:
%%time
# Same inference, passing the input column by name instead of as a struct.
predictions = df.select(classify("data").alias("prediction"))
predictions.show(truncate=120)



+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                              prediction|
+------------------------------------------------------------------------------------------------------------------------+
|[4.7788824E-5, 3.0369216E-4, 7.026308E-5, 1.772744E-4, 4.5683184E-5, 1.9182982E-4, 9.988331E-6, 7.264462E-5, 1.418365...|
|[1.2047234E-5, 2.466537E-4, 2.3863082E-4, 1.9301384E-4, 1.5375504E-4, 4.5055393E-5, 2.2207663E-5, 3.7911904E-4, 1.570...|
|[1.3302344E-4, 2.6962307E-4, 5.5178945E-5, 9.990197E-5, 4.761874E-5, 4.4045786E-4, 6.8055238E-6, 3.486012E-5, 1.26733...|
|[1.5190398E-5, 2.92972E-4, 1.17424854E-4, 8.6468666E-5, 7.0272654E-5, 7.291892E-5, 1.1590384E-5, 2.9054214E-4, 1.7357...|
|[1.12572874E-4, 2.5857877E-4, 5.529736E-5, 1.04460785E-4, 4.6776888E-5, 3.4146404E-4, 6.684954E-6, 3.6820937E-5, 1.15...|
|[9.9512836E-5, 

                                                                                

In [45]:
%%time
# Input passed as a Column; writing (unlike show) runs inference over every row.
predictions = df.select(classify(col("data")).alias("prediction"))
predictions.write.mode("overwrite").parquet(output_file_path + "_2")



CPU times: user 125 ms, sys: 84.5 ms, total: 209 ms
Wall time: 29.5 s


                                                                                

#### Stop Triton Server on each executor

In [46]:
def stop_triton(it):
    """Stop every running container matching the "spark-triton" name filter on this executor.

    Intended for a barrier `mapPartitions` with one partition per executor. The
    containers were started with `remove=True`, so stopping them also removes them.

    Args:
        it: partition iterator supplied by `mapPartitions` (unused).

    Returns:
        `[True]`, one marker per partition for the driver's `collect()`.
    """
    import docker

    client = docker.from_env()
    containers = client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    # Stop all matches, not just the first, so no stray server keeps holding the GPU/ports.
    for container in containers:
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

>>>> stopping containers: ['a13519cbf860']
                                                                                

[True]

In [47]:
# Shut down the Spark session created at the top of the notebook.
spark.stop()