<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark PyTorch Inference

### Regression

In this notebook, we will train an MLP to perform regression on the California housing dataset, and load it for distributed inference with Spark.  

Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md  

We also demonstrate accelerated inference via Torch-TensorRT model compilation.   

In [1]:
import torch
import os
import shutil
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [2]:
os.mkdir('models') if not os.path.exists('models') else None

In [3]:
torch.__version__

'2.4.1+cu121'

### Load Dataset

Each label corresponds to the average house value in units of 100,000, which we'll try to predict using the following features:  
['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude']

In [4]:
X, y = fetch_california_housing(return_X_y=True)

In [5]:
class HousingDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            # Apply scaling if necessary
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X.astype(np.float32))
            self.y = torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

In [6]:
dataset = HousingDataset(X, y)
trainloader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=1)

In [7]:
next(iter(trainloader))

[tensor([[-0.5241, -0.8454, -0.2556,  0.3256,  0.8226,  0.0041,  0.7623, -1.1132],
         [-0.3852, -1.1632, -0.3439, -0.0999, -0.3060, -0.0110, -0.5767,  0.0198],
         [ 0.0802,  0.5054, -0.0854, -0.0261,  0.1170,  0.0557, -0.8062,  0.7485],
         [-0.7215, -0.5276, -0.2073, -0.1787, -0.6115, -0.0134,  1.3756, -0.8686],
         [-0.9009, -0.4481, -0.4374, -0.1855, -0.1788,  0.0477,  1.0760, -0.8537],
         [ 0.6685,  1.3794,  0.2930, -0.1939, -0.8066, -0.0586,  0.9402, -1.1731],
         [-0.7873, -0.9249,  0.1057, -0.0384,  0.3025, -0.0482,  0.1817,  0.2794],
         [ 0.2764,  1.8562, -1.3384, -0.3592, -0.3219,  1.2067,  1.0010, -1.4127],
         [-0.2856, -1.7194, -0.1692,  0.0904,  1.0177, -0.0792, -0.6938,  1.1279],
         [ 1.3093, -0.2097,  0.5685,  0.0202, -0.2477, -0.0660,  0.9121, -1.4127]]),
 tensor([1.8440, 2.4870, 1.5770, 1.0160, 0.6780, 3.1560, 0.7980, 2.2500, 1.2220,
         5.0000])]

### Create and Train Model

In [8]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [9]:
# Initialize the MLP
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
mlp = MLP().to(device)

# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

Using cuda device


In [10]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader, 0):

        # Get and prepare inputs
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp(inputs)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

        # Print statistics
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.003
Loss after mini-batch   201: 0.709
Loss after mini-batch   401: 0.505
Loss after mini-batch   601: 0.379
Loss after mini-batch   801: 0.305
Loss after mini-batch  1001: 0.269
Loss after mini-batch  1201: 0.257
Loss after mini-batch  1401: 0.227
Loss after mini-batch  1601: 0.214
Loss after mini-batch  1801: 0.223
Loss after mini-batch  2001: 0.223
Starting epoch 2
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.214
Loss after mini-batch   401: 0.206
Loss after mini-batch   601: 0.200
Loss after mini-batch   801: 0.199
Loss after mini-batch  1001: 0.203
Loss after mini-batch  1201: 0.197
Loss after mini-batch  1401: 0.197
Loss after mini-batch  1601: 0.197
Loss after mini-batch  1801: 0.188
Loss after mini-batch  2001: 0.193
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.182
Loss after mini-batch   401: 0.186
Loss after mini-batch   601: 0.195
Loss after mini-batch   801: 0.195
Loss

### Save Model State Dict
This saves the serialized object to disk using pickle.

In [11]:
torch.save(mlp.state_dict(), "models/housing_model.pt")
print("Saved PyTorch Model State to models/housing_model.pt")

Saved PyTorch Model State to models/housing_model.pt


### Save Model as TorchScript
This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). 

In [12]:
scripted = torch.jit.script(mlp)
scripted.save("models/ts_housing_model.pt")
print("Saved TorchScript Model to models/ts_housing_model.pt")

Saved TorchScript Model to models/ts_housing_model.pt


### Load and Test from Model State

In [13]:
loaded_mlp = MLP().to(device)
loaded_mlp.load_state_dict(torch.load("models/housing_model.pt", weights_only=True))

<All keys matched successfully>

In [14]:
testX, testY = next(iter(trainloader))

In [15]:
print("Predictions:")
loaded_mlp(testX.to(device))

Predictions:


tensor([[2.0423],
        [1.9258],
        [2.8864],
        [3.2128],
        [2.8639],
        [1.0359],
        [2.8652],
        [1.5528],
        [1.7592],
        [0.7497]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [16]:
print("Labels:")
testY

Labels:


tensor([2.2030, 2.1590, 1.9400, 3.4310, 2.4480, 1.1410, 2.8780, 0.8310, 1.6530,
        1.2380])

### Load and Test from TorchScript

In [17]:
scripted_mlp = torch.jit.load("models/ts_housing_model.pt")

In [18]:
print("Predictions:")
scripted_mlp(testX.to(device)).flatten()

Predictions:


tensor([2.0423, 1.9258, 2.8864, 3.2128, 2.8639, 1.0359, 2.8652, 1.5528, 1.7592,
        0.7497], device='cuda:0', grad_fn=<ViewBackward0>)

### Compile using the Torch JIT Compiler
This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  

Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark.  

(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)

In [19]:
import torch_tensorrt as trt
import time



In [20]:
# Optional: set the filename for the TensorRT timing cache
timestamp = time.time()
timing_cache = f"/tmp/timing_cache-{timestamp}.bin"
with open(timing_cache, "wb") as f:
    pass

In [21]:
inputs_bs1 = torch.randn((10, 8), dtype=torch.float).to("cuda")
# This indicates dimension 0 of inputs_bs1 is dynamic with a range of values [1, 50]. No recompilation will happen when the batch size changes.
torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=50)
trt_model = trt.compile(
    loaded_mlp,
    ir="torch_compile",
    inputs=inputs_bs1,
    enabled_precisions={torch.float},
    timing_cache_path=timing_cache,
)

In [22]:
stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    testX = testX.to(device)
    print("Predictions:")
    print(trt_model(testX))

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=53

INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 581, GPU 2466 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1635, GPU +288, now: CPU 2363, GPU 2754 (MiB)


Predictions:


  if input_val.dynamic_range is not None and dyn_range_fn is not None:

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.008361
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 22240
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 10 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.147039ms to assign 4 blocks to 10 nodes requiring 7168 bytes.
INFO:torch_tensorrt [TensorRT Conversion Conte

tensor([[2.0423],
        [1.9258],
        [2.8864],
        [3.2128],
        [2.8639],
        [1.0359],
        [2.8652],
        [1.5528],
        [1.7592],
        [0.7497]], device='cuda:0')


### Compile using the Torch-TensorRT AOT Compiler
Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph representing the Tensor computation in an AOT fashion, which produces a `ExportedProgram` object which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   

[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation.

In [23]:
example_inputs = (torch.randn((10, 8), dtype=torch.float).to("cuda"),)

# Mark dim 1 (batch size) as dynamic
batch = torch.export.Dim("batch", min=1, max=64)
# Produce traced graph in ExportedProgram format
exp_program = torch.export.export(loaded_mlp, args=example_inputs, dynamic_shapes={"x": {0: batch}})
# Compile the traced graph to produce an optimized module
trt_gm = trt.dynamo.compile(exp_program,
                            tuple(example_inputs),
                            enabled_precisions={torch.float},
                            timing_cache_path=timing_cache,
                            workspace_size=1<<30)

INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=1073741824, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache-1738007491.6462495.bin')

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemU

In [24]:
print(type(exp_program))
print(type(trt_gm))

<class 'torch.export.exported_program.ExportedProgram'>
<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>


In [25]:
stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    print("Predictions:")
    testX = testX.to(device)
    print(trt_gm(testX))

Predictions:
tensor([[2.0425],
        [1.9257],
        [2.8869],
        [3.2127],
        [2.8640],
        [1.0362],
        [2.8658],
        [1.5528],
        [1.7586],
        [0.7496]], device='cuda:0')


We can run the optimized module with a few different batch sizes (without recompilation!):

In [26]:
inputs = (torch.randn((10, 8), dtype=torch.float).cuda(),)
inputs_bs1 = (torch.randn((1, 8), dtype=torch.float).cuda(),)
inputs_bs50 = (torch.randn((50, 8), dtype=torch.float).cuda(),)

stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    print("Output shapes:")
    print(trt_gm(*inputs).shape)
    print(trt_gm(*inputs_bs1).shape)
    print(trt_gm(*inputs_bs50).shape)

Output shapes:
torch.Size([10, 1])
torch.Size([1, 1])
torch.Size([50, 1])


We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date.

In [27]:
torch.export.save(exp_program, "models/trt_housing_model.ep")
print("Saved ExportedProgram to models/trt_housing_model.ep")

Saved ExportedProgram to models/trt_housing_model.ep


## PySpark

In [28]:
from pyspark.sql.functions import col, struct, pandas_udf, array
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf
import json
import pandas as pd

  from typing.io import BinaryIO  # type: ignore[import]



Check the cluster environment to handle any platform-specific Spark configurations.

In [29]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [30]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
        # Point PyTriton to correct libpython3.11.so:
        conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH")
    elif on_dataproc:
        # Point PyTriton to correct libpython3.11.so:
        conda_lib_path="/opt/conda/miniconda3/lib"
        conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_lib_path}:$LD_LIBRARY_PATH")
        conf.set("spark.executor.instances", "4") # dataproc defaults to 2

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")
    
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/01/27 19:51:37 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/01/27 19:51:37 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/27 19:51:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/01/27 19:51:38 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


### Create Spark DataFrame from Pandas DataFrame

In [31]:
housing = fetch_california_housing()

In [32]:
X = StandardScaler().fit_transform(housing.data.astype(np.float32))

In [33]:
pdf = pd.DataFrame(X, columns=housing.feature_names)

In [34]:
schema = StructType([
    StructField("MedInc",FloatType(),True),
    StructField("HouseAge",FloatType(),True),
    StructField("AveRooms",FloatType(),True),
    StructField("AveBedrms",FloatType(),True),
    StructField("Population",FloatType(),True),
    StructField("AveOccup",FloatType(),True),
    StructField("Latitude",FloatType(),True),
    StructField("Longitude",FloatType(),True)
])

df = spark.createDataFrame(pdf, schema=schema).repartition(8)
df.show(truncate=12)

  elif is_categorical_dtype(s.dtype):

[Stage 0:>                                                          (0 + 8) / 8]

+------------+------------+-----------+------------+-----------+------------+----------+------------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|
+------------+------------+-----------+------------+-----------+------------+----------+------------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|
|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.968

                                                                                

In [35]:
df.schema

StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])

In [36]:
data_path = "spark-dl-datasets/california_housing"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays 
- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function

In [37]:
df = spark.read.parquet(data_path)

In [38]:
columns = df.columns

In [39]:
# get absolute path to model
model_path = "{}/models/trt_housing_model.ep".format(os.getcwd())

# For cloud environments, copy the model to the distributed file system.
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    dbfs_model_path = "/dbfs/FileStore/spark-dl-models/trt_housing_model.ep"
    shutil.copy(model_path, dbfs_model_path)
    model_path = dbfs_model_path
elif on_dataproc:
    # GCS is mounted at /mnt/gcs by the init script
    models_dir = "/mnt/gcs/spark-dl/models"
    os.mkdir(models_dir) if not os.path.exists(models_dir) else None
    gcs_model_path = models_dir + "/trt_housing_model.ep"
    shutil.copy(model_path, gcs_model_path)
    model_path = gcs_model_path

For inference on Spark, we'll load the ExportedProgram and compile the model with the Torch-TensorRT AOT compiler and cache on the executor. 

In [40]:
# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.
# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.

import warnings
warnings.simplefilter("ignore", ResourceWarning)

In [41]:
def predict_batch_fn():
    import torch
    import torch_tensorrt as trt

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        raise ValueError("This function uses the TensorRT model which requires a GPU device")

    example_inputs = (torch.randn((50, 8), dtype=torch.float).to("cuda"),)
    exp_program = torch.export.load(model_path)
    trt_gm = trt.dynamo.compile(exp_program,
                            tuple(example_inputs),
                            enabled_precisions={torch.float},
                            timing_cache_path=timing_cache,
                            workspace_size=1<<30)

    print("Model compiled.")
    
    def predict(inputs):
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream), trt.logging.errors():
            print(f"Predict {inputs.shape}")
            torch_inputs = torch.from_numpy(inputs).to(device)
            outputs = trt_gm(torch_inputs) # .flatten()
            return outputs.detach().cpu().numpy()

    return predict

In [42]:
regress = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=50)

In [43]:
%%time
preds = df.withColumn("preds", regress(struct(*columns)))
results = preds.collect()



CPU times: user 29.2 ms, sys: 4.38 ms, total: 33.6 ms
Wall time: 8.23 s


                                                                                

In [44]:
%%time
preds = df.withColumn("preds", regress(array(*columns)))
results = preds.collect()

CPU times: user 30.7 ms, sys: 6.15 ms, total: 36.8 ms
Wall time: 309 ms


In [45]:
%%time
preds = df.withColumn("preds", regress(array(*columns)))
results = preds.collect()

CPU times: user 27.3 ms, sys: 4.6 ms, total: 31.9 ms
Wall time: 276 ms


In [46]:
preds.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|    preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.4441836|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|1.7245855|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|1.3524103|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|2.3009148|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.272077|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|0.6

In [47]:
# This will clear the engine cache (containing previously compiled TensorRT engines) and reset the CUDA Context.
torch._dynamo.reset()

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-pytriton.png" alt="drawing" width="700"/>

In [48]:
from functools import partial

Import the utility functions from pytriton_utils.py:

In [49]:
sc.addPyFile("pytriton_utils.py")

from pytriton_utils import (
    use_stage_level_scheduling,
    find_ports,
    start_triton,
    stop_triton
)

Define the Triton Server function:

In [50]:
def triton_server(ports, model_path):
    import time
    import signal
    import numpy as np
    import torch
    import torch_tensorrt as trt
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    exp_program = torch.export.load(model_path)
    example_inputs = (torch.randn((50, 8), dtype=torch.float).to("cuda"),)
    trt_gm = trt.dynamo.compile(exp_program,
                            tuple(example_inputs),
                            enabled_precisions={torch.float},
                            workspace_size=1<<30)

    print("SERVER: Compiled model.")

    @batch
    def _infer_fn(**inputs):
        features = inputs["features"]
        if len(inputs["features"]) != 1:
            features = np.squeeze(features)
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream):
            torch_inputs = torch.from_numpy(features).to(device)
            outputs = trt_gm(torch_inputs)
            return {
                "preds": outputs.cpu().numpy(),
            }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="HousingModel",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="features", dtype=np.float32, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="preds", dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=50,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

#### Start Triton servers

**Specify the number of nodes in the cluster.**  
Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. 

In [51]:
# Change based on cluster setup
num_nodes = 1 if on_standalone else 4

To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU.  

In [52]:
sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = use_stage_level_scheduling(spark, nodeRDD)

Requesting stage-level resources: (cores=5, gpu=1.0)


Triton occupies ports for HTTP requests, GRPC requests, and the metrics service.

In [53]:
model_name = "HousingModel"
ports = find_ports()
assert len(ports) == 3
print(f"Using ports {ports}")

Using ports [7000, 7001, 7002]


In [54]:
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(triton_server_fn=triton_server,
                                                              ports=ports,
                                                              model_name=model_name,
                                                              model_path=model_path)).collectAsMap()
print("Triton Server PIDs:\n", json.dumps(pids, indent=4))

[Stage 11:>                                                         (0 + 1) / 1]

Triton Server PIDs:
 {
    "cb4ae00-lcedt": 2964321
}


                                                                                

#### Define client function

In [55]:
url = f"http://localhost:{ports[0]}"

In [56]:
def triton_fn(url, model_name):
    from pytriton.client import ModelClient

    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=240) as client:
            result_data = client.infer_batch(inputs)
            return result_data["preds"]
        
    return infer_batch

### Run Inference

In [57]:
df = spark.read.parquet(data_path)

In [58]:
columns = df.columns

In [59]:
regress = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),
                               input_tensor_shapes=[[8]],
                               return_type=FloatType(),
                               batch_size=50)

In [60]:
%%time
# first pass caches model/fn
predictions = df.withColumn("preds", regress(struct(*columns)))
preds = predictions.collect()



CPU times: user 21.1 ms, sys: 8.1 ms, total: 29.2 ms
Wall time: 1.01 s


                                                                                

In [61]:
%%time
predictions = df.withColumn("preds", regress(array(*columns)))
preds = predictions.collect()

[Stage 14:>                                                         (0 + 8) / 8]

CPU times: user 23.7 ms, sys: 4.15 ms, total: 27.8 ms
Wall time: 973 ms


                                                                                

In [62]:
%%time
predictions = df.withColumn("preds", regress(array(*columns)))
preds = predictions.collect()

[Stage 15:>                                                         (0 + 8) / 8]

CPU times: user 34.3 ms, sys: 6.47 ms, total: 40.7 ms
Wall time: 1.49 s


                                                                                

In [63]:
predictions.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|     preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053| 1.4441836|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742| 1.7245854|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378| 1.3524104|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|  2.300915|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.2720771|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.008

#### Stop Triton Server on each executor

In [64]:
shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
shutdownRDD = use_stage_level_scheduling(spark, shutdownRDD)
shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()

Requesting stage-level resources: (cores=5, gpu=1.0)


                                                                                

[True]

In [65]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()