<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark TensorFlow Inference

### Classification using Keras Preprocessing Layers

In this notebook, we demonstrate distributed inference using Keras preprocessing layers to classify structured data.  
From: https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers

Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)  

In [1]:
import os
import shutil
import numpy as np
import pandas as pd
import tensorflow as tf

from tensorflow.keras import layers

2025-02-04 13:59:29.670948: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-04 13:59:29.679838: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-04 13:59:29.689914: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-04 13:59:29.692851: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-04 13:59:29.700499: I tensorflow/core/platform/cpu_feature_guar

In [2]:
os.makedirs('models', exist_ok=True)

In [None]:
print(tf.__version__)

# Enable GPU memory growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

2.17.0


I0000 00:00:1738706370.524690 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1738706370.550329 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1738706370.553239 3686377 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355


#### Download dataset

Download the PetFinder dataset (originally published on Kaggle), in which each row describes a pet; the goal is to predict adoption speed.

In [4]:
import pathlib
dataset_url = 'http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip'

data_dir = tf.keras.utils.get_file('petfinder_mini.zip', dataset_url, extract=True, cache_dir='.')
data_dir = pathlib.Path(data_dir)
try:
    # The CSV may sit under a parent directory named petfinder_mini_extracted. Check that location first:
    dataset = os.path.join(os.path.dirname(data_dir), 'petfinder_mini_extracted/petfinder-mini/petfinder-mini.csv')
    dataframe = pd.read_csv(dataset)
except FileNotFoundError:
    # Otherwise fall back to the layout without the extra parent directory.
    dataset = os.path.join(os.path.dirname(data_dir), 'petfinder-mini/petfinder-mini.csv')
    dataframe = pd.read_csv(dataset)

Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip
1668792/1668792 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step


In [5]:
dataframe.head()

Unnamed: 0,Type,Age,Breed1,Gender,Color1,Color2,MaturitySize,FurLength,Vaccinated,Sterilized,Health,Fee,Description,PhotoAmt,AdoptionSpeed
0,Cat,3,Tabby,Male,Black,White,Small,Short,No,No,Healthy,100,Nibble is a 3+ month old ball of cuteness. He ...,1,2
1,Cat,1,Domestic Medium Hair,Male,Black,Brown,Medium,Medium,Not Sure,Not Sure,Healthy,0,I just found it alone yesterday near my apartm...,2,0
2,Dog,1,Mixed Breed,Male,Brown,White,Medium,Medium,Yes,No,Healthy,0,Their pregnant mother was dumped by her irresp...,7,3
3,Dog,4,Mixed Breed,Female,Black,Brown,Medium,Short,Yes,No,Healthy,150,"Good guard dog, very alert, active, obedience ...",8,2
4,Dog,1,Mixed Breed,Male,Black,No Color,Medium,Short,No,No,Healthy,0,This handsome yet cute boy is up for adoption....,3,2


### Prepare dataset

In [6]:
# In the original dataset, `'AdoptionSpeed'` of `4` indicates
# a pet was not adopted.
dataframe['target'] = np.where(dataframe['AdoptionSpeed']==4, 0, 1)

# Drop unused features.
dataframe = dataframe.drop(columns=['AdoptionSpeed', 'Description'])
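
As a quick illustration of the relabeling rule, here is a plain-Python sketch that maps an `AdoptionSpeed` of 4 to target 0 and every other value to 1. The `adoption_speeds` values are made up for the example.

```python
# Hypothetical AdoptionSpeed values; 4 means the pet was not adopted.
adoption_speeds = [2, 0, 3, 4, 1]

# Same rule as np.where(speed == 4, 0, 1), written without NumPy.
targets = [0 if speed == 4 else 1 for speed in adoption_speeds]
print(targets)
```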

In [None]:
train, val, test = np.split(dataframe.sample(frac=1), [int(0.8*len(dataframe)), int(0.9*len(dataframe))])



In [8]:
print(len(train), 'training examples')
print(len(val), 'validation examples')
print(len(test), 'test examples')

9229 training examples
1154 validation examples
1154 test examples
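
The split sizes follow from the two cut points passed to `np.split`. A small sketch of the arithmetic, assuming the total row count implied by the three counts printed above:

```python
n_rows = 9229 + 1154 + 1154          # total rows implied by the printed counts

train_end = int(0.8 * n_rows)        # first cut point: end of the training slice
val_end = int(0.9 * n_rows)          # second cut point: end of the validation slice

# Sizes of the train, validation, and test slices.
sizes = (train_end, val_end - train_end, n_rows - val_end)
print(sizes)
```

Because `int()` truncates, any leftover rows fall into the later slices rather than the training slice.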


Create an input pipeline that converts each pandas DataFrame into a `tf.data.Dataset` with shuffling and batching.

In [9]:
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
    df = dataframe.copy()
    labels = df.pop('target')
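    # Note: this iterates over `dataframe`, not `df`, so 'target' stays in the
    # feature dict (see the printed keys below). The model built later only
    # consumes the inputs it declares by name.
    df = {key: value.to_numpy()[:,tf.newaxis] for key, value in dataframe.items()}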
    ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataframe))
    ds = ds.batch(batch_size)
    ds = ds.prefetch(batch_size)
    return ds

Check the format of the data returned by the pipeline:

In [None]:
batch_size = 5
train_ds = df_to_dataset(train, batch_size=batch_size)


(Note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963).

In [None]:
[(train_features, label_batch)] = train_ds.take(1)
print('Every feature:', list(train_features.keys()))
print('A batch of ages:', train_features['Age'])
print('A batch of targets:', label_batch )

Every feature: ['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']
A batch of ages: tf.Tensor(
[[ 4]
 [60]
 [24]
 [ 1]
 [ 2]], shape=(5, 1), dtype=int64)
A batch of targets: tf.Tensor([1 1 1 1 1], shape=(5,), dtype=int64)


2025-02-04 13:59:31.170523: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


### Apply Keras preprocessing layers

We'll define a normalization layer for numeric features, and a category encoding for categorical features.

In [12]:
def get_normalization_layer(name, dataset):
    # Create a Normalization layer for the feature.
    normalizer = layers.Normalization(axis=None)

    # Prepare a Dataset that only yields the feature.
    feature_ds = dataset.map(lambda x, y: x[name])

    # Learn the statistics of the data.
    normalizer.adapt(feature_ds)

    return normalizer

In [13]:
photo_count_col = train_features['PhotoAmt']
layer = get_normalization_layer('PhotoAmt', train_ds)
layer(photo_count_col)

2025-02-04 13:59:32.726183: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<tf.Tensor: shape=(5, 1), dtype=float32, numpy=
array([[-0.19333968],
       [-0.19333968],
       [-0.19333968],
       [-0.51676387],
       [ 1.100357  ]], dtype=float32)>
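
To make the standardization concrete, here is a plain-Python sketch of what we understand a `Normalization(axis=None)` layer to compute after `adapt`: a single mean and variance over all values, then `(x - mean) / sqrt(variance)`. The sample values are invented, and the sketch leaves out any numerical safeguards the real layer applies.

```python
import math

def fit_normalizer(values):
    """Return (mean, std) using the population variance over all values."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, math.sqrt(variance)

def normalize(values, mean, std):
    """Standardize values to zero mean and unit variance."""
    return [(v - mean) / std for v in values]

photo_counts = [1.0, 2.0, 3.0, 4.0, 5.0]   # made-up PhotoAmt values
mean, std = fit_normalizer(photo_counts)
scaled = normalize(photo_counts, mean, std)
print(mean, std, scaled)
```

Repeated values in a batch map to the same output, which is why the tensor above shows the same number three times.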

In [14]:
def get_category_encoding_layer(name, dataset, dtype, max_tokens=None):
    # Create a layer that turns strings into integer indices.
    if dtype == 'string':
        index = layers.StringLookup(max_tokens=max_tokens)
    # Otherwise, create a layer that turns integer values into integer indices.
    else:
        index = layers.IntegerLookup(max_tokens=max_tokens)

    # Prepare a `tf.data.Dataset` that only yields the feature.
    feature_ds = dataset.map(lambda x, y: x[name])

    # Learn the set of possible values and assign them a fixed integer index.
    index.adapt(feature_ds)

    # Encode the integer indices.
    encoder = layers.CategoryEncoding(num_tokens=index.vocabulary_size())

    # Apply multi-hot encoding to the indices. The lambda function captures both
    # layers, so you can call them here or include them in the Keras Functional model later.
    return lambda feature: encoder(index(feature))

In [15]:
test_type_col = train_features['Type']
test_type_layer = get_category_encoding_layer(name='Type',
                                              dataset=train_ds,
                                              dtype='string')
test_type_layer(test_type_col)

<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 1., 0.]], dtype=float32)>

In [16]:
test_age_col = train_features['Age']
test_age_layer = get_category_encoding_layer(name='Age',
                                             dataset=train_ds,
                                             dtype='int64',
                                             max_tokens=5)
test_age_layer(test_age_col)

2025-02-04 13:59:34.294276: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32)>
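
The two steps inside `get_category_encoding_layer` can be approximated in plain Python. This is a rough sketch, not the Keras implementation. It assumes a frequency-ordered vocabulary with index 0 reserved for out-of-vocabulary values, and a `max_tokens` budget that includes that reserved slot, which would explain the width-5 output for `max_tokens=5` above. The helper names and the sample column are hypothetical.

```python
from collections import Counter

def build_vocabulary(values, max_tokens=None):
    """Most frequent values first; index 0 is reserved for out-of-vocabulary."""
    ranked = [value for value, _ in Counter(values).most_common()]
    if max_tokens is not None:
        ranked = ranked[: max_tokens - 1]   # one slot is taken by the OOV token
    return ["[UNK]"] + ranked

def one_hot(value, vocabulary):
    """Encode a single value as a vector with a 1 at its vocabulary index."""
    index = vocabulary.index(value) if value in vocabulary[1:] else 0
    return [1.0 if i == index else 0.0 for i in range(len(vocabulary))]

pet_types = ["Dog", "Cat", "Dog", "Dog", "Cat", "Bird"]   # made-up column
vocab = build_vocabulary(pet_types, max_tokens=3)
print(vocab)
print(one_hot("Dog", vocab), one_hot("Bird", vocab))
```

Values that fall outside the retained vocabulary, such as `"Bird"` here, all share the first position.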

### Preprocess selected features

Apply the preprocessing helper functions defined earlier. Collect the Keras inputs in a dictionary keyed by feature name, and the encoded features in a list.


In [17]:
batch_size = 256
train_ds = df_to_dataset(train, batch_size=batch_size)
val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)
test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)

In [18]:
all_inputs = {}
encoded_features = []

# Numerical features.
for header in ['PhotoAmt', 'Fee']:
    numeric_col = tf.keras.Input(shape=(1,), name=header)
    normalization_layer = get_normalization_layer(header, train_ds)
    encoded_numeric_col = normalization_layer(numeric_col)
    all_inputs[header] = numeric_col
    encoded_features.append(encoded_numeric_col)

In [19]:
age_col = tf.keras.Input(shape=(1,), name='Age', dtype='int64')

encoding_layer = get_category_encoding_layer(name='Age',
                                             dataset=train_ds,
                                             dtype='int64',
                                             max_tokens=5)
encoded_age_col = encoding_layer(age_col)
all_inputs['Age'] = age_col
encoded_features.append(encoded_age_col)

In [20]:
categorical_cols = ['Type', 'Color1', 'Color2', 'Gender', 'MaturitySize',
                    'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Breed1']

for header in categorical_cols:
    categorical_col = tf.keras.Input(shape=(1,), name=header, dtype='string')
    encoding_layer = get_category_encoding_layer(name=header,
                                                dataset=train_ds,
                                                dtype='string',
                                                max_tokens=5)
    encoded_categorical_col = encoding_layer(categorical_col)
    all_inputs[header] = categorical_col
    encoded_features.append(encoded_categorical_col)

2025-02-04 13:59:34.588989: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-02-04 13:59:35.029267: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


### Create, compile, and train model

In [21]:
all_features = tf.keras.layers.concatenate(encoded_features)
x = tf.keras.layers.Dense(32, activation="relu")(all_features)
x = tf.keras.layers.Dropout(0.5)(x)
output = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(all_inputs, output)

In [22]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=["accuracy"],
              run_eagerly=True)

In [23]:
model.fit(train_ds, epochs=10, validation_data=val_ds)

Epoch 1/10
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - accuracy: 0.3658 - loss: 0.7746 - val_accuracy: 0.6854 - val_loss: 0.5841
Epoch 2/10
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - accuracy: 0.6270 - loss: 0.6023 - val_accuracy: 0.7383 - val_loss: 0.5593
Epoch 3/10
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 16ms/step - accuracy: 0.6650 - loss: 0.5781 - val_accuracy: 0.7392 - val_loss: 0.5442
Epoch 4/10
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - accuracy: 0.6609 - loss: 0.5744 - val_accuracy: 0.7418 - val_loss: 0.5329
Epoch 5/10
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - accuracy: 0.6845 - loss: 0.5555 - val_accuracy: 0.7444 - val_loss: 0.5261
Epoch 6/10
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - accuracy: 0.6910 - loss: 0.5465 - val_accuracy: 0.7513 - val_loss: 0.5198
Epoch 7/10
37/37 ━━━━

<keras.src.callbacks.history.History at 0x7b15e89123d0>

In [None]:
loss, accuracy = model.evaluate(test_ds)
print("Accuracy", accuracy)

5/5 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7416 - loss: 0.5196
Accuracy 0.7443674206733704


### Save and reload model

Demonstrate saving the trained model and reloading it for inference.

In [25]:
model.save('models/my_pet_classifier.keras')

In [26]:
reloaded_model = tf.keras.models.load_model('models/my_pet_classifier.keras')

In [None]:
sample = {
    'Type': 'Cat',
    'Age': 3,
    'Breed1': 'Tabby',
    'Gender': 'Male',
    'Color1': 'Black',
    'Color2': 'White',
    'MaturitySize': 'Small',
    'FurLength': 'Short',
    'Vaccinated': 'No',
    'Sterilized': 'No',
    'Health': 'Healthy',
    'Fee': 100,
    'PhotoAmt': 2,
}

input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = reloaded_model.predict(input_dict)
prob = tf.nn.sigmoid(predictions[0])

print(
    "This particular pet had a %.1f percent probability "
    "of getting adopted." % (100 * prob)
)

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step
This particular pet had a 83.2 percent probability of getting adopted.


## PySpark

In [28]:
from pyspark.sql.functions import col, struct, pandas_udf
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf
import json
import pandas as pd

Check the cluster environment to handle any platform-specific Spark configurations.

In [29]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [30]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
    elif on_dataproc:
        conf.set("spark.executorEnv.TF_GPU_ALLOCATOR", "cuda_malloc_async")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/02/04 13:59:42 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/02/04 13:59:42 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/04 13:59:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
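
The resource settings above are chosen so that every core on an executor can run a task against the executor's single GPU. A small sketch of the arithmetic follows; the variable names are illustrative and are not Spark APIs.

```python
executor_cores = 8            # spark.executor.cores
executor_gpus = 1             # spark.executor.resource.gpu.amount
task_gpu_amount = 0.125       # spark.task.resource.gpu.amount

# Number of tasks that can share the executor's GPU at once.
tasks_per_gpu_limit = int(executor_gpus / task_gpu_amount)

# Concurrency is capped by whichever resource runs out first
# (assuming the default of one core per task).
concurrent_tasks = min(executor_cores, tasks_per_gpu_limit)
print(concurrent_tasks)
```

With these values neither resource is the bottleneck: the GPU fraction allows exactly as many tasks as there are cores.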


#### Create PySpark DataFrame

In [31]:
df = spark.createDataFrame(dataframe).repartition(8)

In [32]:
data_path = "spark-dl-datasets/petfinder-mini"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

                                                                                

#### Load and preprocess DataFrame

In [33]:
df = spark.read.parquet(data_path).cache()

In [34]:
columns = df.columns
print(columns)

['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']


In [35]:
# remove label column
columns.remove("target")
print(columns)

['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt']


In [36]:
df.show(5)

+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+
|Type|Age|              Breed1|Gender|Color1|  Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|
+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+
| Dog|  3|         Mixed Breed|  Male| Black|No Color|       Small|   Medium|  Not Sure|  Not Sure|Healthy|  0|       2|     0|
| Dog|  9|         Mixed Breed|  Male|  Gray|No Color|      Medium|    Short|  Not Sure|        No|Healthy|  0|       4|     1|
| Cat|  4| Domestic Short Hair|  Male| Black|    Gray|      Medium|    Short|  Not Sure|  Not Sure|Healthy|  0|       4|     1|
| Cat|  6| Domestic Short Hair|  Male|Yellow|   White|      Medium|    Short|        No|        No|Healthy|  0|       3|     1|
| Cat|  6|Domestic Medium Hair|  Male|  Gray|No Color|       Small|   Medium|        No|        No|Healt

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses TensorFlow APIs to load the model and return a predict function which operates on NumPy arrays
- predict_batch_udf will convert the Spark DataFrame columns into NumPy input batches for the predict function
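
The division of labor between the two functions can be mimicked without Spark. The sketch below is a simplified stand-in, not the PySpark implementation: `run_in_batches` is a hypothetical helper that calls the factory once and then feeds fixed-size batches to the returned predict function, and the "model" is just a scale factor.

```python
def make_predict_fn():
    """Factory: runs once per worker, so expensive setup (model load) goes here."""
    weights = {"scale": 2.0}          # stand-in for a loaded model

    def predict(batch):
        # Operates on a whole batch at a time, like the real predict function.
        return [weights["scale"] * x for x in batch]

    return predict

def run_in_batches(make_fn, rows, batch_size):
    """Hypothetical stand-in for what predict_batch_udf does with a partition."""
    predict = make_fn()               # setup happens once, not per batch
    results = []
    for start in range(0, len(rows), batch_size):
        results.extend(predict(rows[start:start + batch_size]))
    return results

outputs = run_in_batches(make_predict_fn, [1.0, 2.0, 3.0, 4.0, 5.0], batch_size=2)
print(outputs)
```

Keeping the model load inside the factory, rather than inside `predict`, is what lets the load cost be paid once per worker instead of once per batch.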

In [37]:
# get absolute path to model
model_path = "{}/models/my_pet_classifier.keras".format(os.getcwd())

# For cloud environments, copy the model to the distributed file system.
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    dbfs_model_path = "/dbfs/FileStore/spark-dl-models/my_pet_classifier.keras"
    shutil.copy(model_path, dbfs_model_path)
    model_path = dbfs_model_path
elif on_dataproc:
    # GCS is mounted at /mnt/gcs by the init script
    models_dir = "/mnt/gcs/spark-dl/models"
    os.makedirs(models_dir, exist_ok=True)
    gcs_model_path = models_dir + "/my_pet_classifier.keras"
    shutil.copy(model_path, gcs_model_path)
    model_path = gcs_model_path

In [38]:
def predict_batch_fn():
    import numpy as np
    import tensorflow as tf
    import pandas as pd
    
    # Enable GPU memory growth to avoid CUDA OOM
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)

    model = tf.keras.models.load_model(model_path)

    def predict(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):
        inputs = {
            "Type": t,
            "Age": a,
            "Breed1": b,
            "Gender": g,
            "Color1": c1,
            "Color2": c2,
            "MaturitySize": m,
            "FurLength": f,
            "Vaccinated": v,
            "Sterilized": s,
            "Health": h,
            "Fee": fee,
            "PhotoAmt": p
        }
        # return model.predict(inputs)
        return pd.Series(np.squeeze(model.predict(inputs)))

    return predict

In [39]:
# need to pass the list of columns into the model_udf
classify = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             batch_size=100)

In [40]:
%%time
preds = df.withColumn("preds", classify(struct(*columns)))
results = preds.collect()

25/02/04 13:59:47 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.



CPU times: user 19.8 ms, sys: 9.3 ms, total: 29.1 ms
Wall time: 4.99 s


                                                                                

In [41]:
%%time
preds = df.withColumn("preds", classify(*columns))
results = preds.collect()


CPU times: user 86.9 ms, sys: 13.7 ms, total: 101 ms
Wall time: 1.56 s


                                                                                

In [42]:
%%time
preds = df.withColumn("preds", classify(*[col(c) for c in columns]))
results = preds.collect()


CPU times: user 16.4 ms, sys: 4.46 ms, total: 20.9 ms
Wall time: 1.52 s


                                                                                

In [43]:
preds.show()

+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+
|Type|Age|              Breed1|Gender|Color1|  Color2|MaturitySize|FurLength|Vaccinated|Sterilized|      Health|Fee|PhotoAmt|target|      preds|
+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+
| Dog|  3|         Mixed Breed|  Male| Black|No Color|       Small|   Medium|  Not Sure|  Not Sure|     Healthy|  0|       2|     0|  0.4963937|
| Dog|  9|         Mixed Breed|  Male|  Gray|No Color|      Medium|    Short|  Not Sure|        No|     Healthy|  0|       4|     1|  0.6780287|
| Cat|  4| Domestic Short Hair|  Male| Black|    Gray|      Medium|    Short|  Not Sure|  Not Sure|     Healthy|  0|       4|     1| 0.58800673|
| Cat|  6| Domestic Short Hair|  Male|Yellow|   White|      Medium|    Short|        No|        No|     Healthy|  0|       3|     
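
Because the model ends in `Dense(1)` with no activation and is compiled with `from_logits=True`, the `preds` column holds raw logits rather than probabilities. A minimal sketch of the conversion, using the three `preds` values visible in the table above; the `sigmoid` helper is our own and is not part of the notebook.

```python
import math

def sigmoid(logit):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))

logits = [0.4963937, 0.6780287, 0.58800673]   # first rows of the preds column
probabilities = [sigmoid(z) for z in logits]
print([round(p, 3) for p in probabilities])
```

A logit of 0 corresponds to a probability of 0.5, so positive logits like these indicate the model leans toward "adopted".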

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-server.png" alt="drawing" width="700"/>

In [44]:
from functools import partial

Import the helper class from server_utils.py:

In [45]:
sc.addPyFile("server_utils.py")

from server_utils import TritonServerManager

Define the Triton Server function:

In [46]:
def triton_server(ports, model_path):
    import time
    import signal
    import numpy as np
    import tensorflow as tf
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    # Enable GPU memory growth
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)

    model = tf.keras.models.load_model(model_path)

    def decode(input_tensor):
        return tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(input_tensor))

    def identity(input_tensor):
        return tf.convert_to_tensor(input_tensor)

    input_transforms = {
        "Type": decode,
        "Age": identity,
        "Breed1": decode,
        "Gender": decode,
        "Color1": decode,
        "Color2": decode,
        "MaturitySize": decode,
        "FurLength": decode,
        "Vaccinated": decode,
        "Sterilized": decode,
        "Health": decode,
        "Fee": identity,
        "PhotoAmt": identity
    }

    @batch
    def _infer_fn(**inputs):
        decoded_inputs = {k: input_transforms[k](v) for k, v in inputs.items()}
        print(f"SERVER: Received batch of size {len(decoded_inputs['Type'])}.")
        return {
            "preds": model.predict(decoded_inputs)
        }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="PetClassifier",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="Type", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Age", dtype=np.int64, shape=(-1,)),
                Tensor(name="Breed1", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Gender", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Color1", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Color2", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="MaturitySize", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="FurLength", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Vaccinated", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Sterilized", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Health", dtype=np.bytes_, shape=(-1,)),
                Tensor(name="Fee", dtype=np.int64, shape=(-1,)),
                Tensor(name="PhotoAmt", dtype=np.int64, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="preds", dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=128,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP/gRPC/metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
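
The contents of `server_utils.py` are not shown in this notebook, so the following only illustrates the first step and is not the manager's actual code. `find_free_ports` is a hypothetical helper that scans upward from an assumed base port of 7000 and keeps the ports that can currently be bound on the local host.

```python
import socket

def find_free_ports(count, start=7000, host="127.0.0.1"):
    """Return `count` ports at or above `start` that can currently be bound."""
    ports = []
    port = start
    while len(ports) < count and port < 65536:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port))   # succeeds only if the port is free
                ports.append(port)
            except OSError:
                pass                      # port in use; try the next one
        port += 1
    return ports

http_port, grpc_port, metrics_port = find_free_ports(3)
print(http_port, grpc_port, metrics_port)
```

A check like this is inherently racy: another process could take a port between the probe and the server start, so a real implementation would need to handle a failed bind at launch time.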

In [None]:
model_name = "PetClassifier"
server_manager = TritonServerManager(model_name=model_name, model_path=model_path)

In [None]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server)

2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-07 11:03:44,810 - INFO - Starting 1 servers.


                                                                                

{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}

#### Define client function

Get the hostname -> url mapping from the server manager:

In [None]:
host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url

Define the Triton inference function, which returns a predict function for batch inference through the server:

In [51]:
def triton_fn(model_name, host_to_url):
    import socket
    import numpy as np
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"CLIENT: Connecting to {model_name} at {url}")

    def infer_batch(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):
        
        def encode(value):
            return np.vectorize(lambda x: x.encode("utf-8"))(value).astype(np.bytes_)

        with ModelClient(url, model_name, inference_timeout_s=240) as client:
            encoded_inputs = {
                "Type": encode(t), 
                "Age": a, 
                "Breed1": encode(b), 
                "Gender": encode(g),
                "Color1": encode(c1),
                "Color2": encode(c2),
                "MaturitySize": encode(m),
                "FurLength": encode(f),
                "Vaccinated": encode(v),
                "Sterilized": encode(s),
                "Health": encode(h),
                "Fee": fee,
                "PhotoAmt": p
            }
            result_data = client.infer_batch(**encoded_inputs)
            return result_data["preds"]
            
    return infer_batch
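
String columns make a round trip through bytes: the client encodes them to `np.bytes_` before sending, and the server's `decode` turns them back into Python strings for the Keras lookup layers. A small sketch of that round trip, assuming NumPy is installed and using made-up breed names shaped like a `(batch, 1)` input:

```python
import numpy as np

def encode(values):
    """Client side: UTF-8 encode each string into a bytes array."""
    return np.vectorize(lambda x: x.encode("utf-8"))(values).astype(np.bytes_)

def decode(values):
    """Server side: turn the bytes back into strings."""
    return np.vectorize(lambda x: x.decode("utf-8"))(values)

breeds = np.array([["Tabby"], ["Domestic Short Hair"], ["Mixed Breed"]])
encoded = encode(breeds)      # fixed-width bytes dtype, same shape as the input
decoded = decode(encoded)     # back to a unicode string array
print(encoded.dtype.kind, decoded.tolist())
```

Numeric columns such as `Age`, `Fee`, and `PhotoAmt` skip this step, which is why they map to `identity` on the server.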

In [54]:
# need to pass the list of columns into the model_udf
classify = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),
                             input_tensor_shapes=[[1]] * len(columns),
                             return_type=FloatType(),
                             batch_size=64)

#### Load and preprocess DataFrame

In [52]:
df = spark.read.parquet(data_path)

In [53]:
columns = df.columns
# remove label column
columns.remove("target")

#### Run inference

In [55]:
%%time
preds = df.withColumn("preds", classify(struct(*columns)))
results = preds.collect()


CPU times: user 17.3 ms, sys: 7.75 ms, total: 25.1 ms
Wall time: 6.35 s


                                                                                

In [56]:
%%time
preds = df.withColumn("preds", classify(*columns))
results = preds.collect()

                                                                                

CPU times: user 15.2 ms, sys: 4.2 ms, total: 19.4 ms
Wall time: 5.86 s


In [57]:
%%time
preds = df.withColumn("preds", classify(*[col(c) for c in columns]))
results = preds.collect()



CPU times: user 93.4 ms, sys: 3.4 ms, total: 96.8 ms
Wall time: 5.87 s


                                                                                

In [58]:
preds.show()

+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+
|Type|Age|              Breed1|Gender|Color1|  Color2|MaturitySize|FurLength|Vaccinated|Sterilized|      Health|Fee|PhotoAmt|target|      preds|
+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+
| Dog|  3|         Mixed Breed|  Male| Black|No Color|       Small|   Medium|  Not Sure|  Not Sure|     Healthy|  0|       2|     0|  0.4963937|
| Dog|  9|         Mixed Breed|  Male|  Gray|No Color|      Medium|    Short|  Not Sure|        No|     Healthy|  0|       4|     1|  0.6780287|
| Cat|  4| Domestic Short Hair|  Male| Black|    Gray|      Medium|    Short|  Not Sure|  Not Sure|     Healthy|  0|       4|     1| 0.58800673|
| Cat|  6| Domestic Short Hair|  Male|Yellow|   White|      Medium|    Short|        No|        No|     Healthy|  0|       3|     

#### Stop Triton Server on each executor

In [59]:
server_manager.stop_servers()

2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers.                 


[True]

In [60]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()