<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark PyTorch Inference

### Regression

This notebook demonstrates distributed inference to perform regression on the California housing dataset.  

Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md  

For the MLP, we'll also demonstrate accelerated GPU inference with Torch-TensorRT. 

In [1]:
import torch
import os
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [2]:
torch.__version__

'2.4.1+cu121'

In [3]:
os.makedirs('models', exist_ok=True)

### Load Dataset

Each label is the median house value of a district, in units of $100,000. We'll try to predict this value from the features:  
['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude']

In [4]:
X, y = fetch_california_housing(return_X_y=True)

In [5]:
class HousingDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            # Standardize the features if requested
            if scale_data:
                X = StandardScaler().fit_transform(X)
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
        # Inputs that are already tensors are stored as-is
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]
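
`StandardScaler` standardizes each feature column to zero mean and unit variance. A minimal NumPy sketch of that transform, applied to a small made-up array rather than the housing data (the helper name `standardize` is hypothetical):

```python
import numpy as np


def standardize(X: np.ndarray) -> np.ndarray:
    """Column-wise (x - mean) / std, using the population std (ddof=0)."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # Guard against constant columns to avoid division by zero.
    std = np.where(std == 0.0, 1.0, std)
    return (X - mean) / std


# Made-up values for illustration only.
X_demo = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 60.0]])
X_scaled = standardize(X_demo).astype(np.float32)
```

After scaling, every column of `X_scaled` has mean close to 0 and standard deviation close to 1, which is the form the dataset class hands to the model.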

In [6]:
dataset = HousingDataset(X, y)
trainloader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=1)

In [7]:
next(iter(trainloader))

[tensor([[-0.3431,  0.1876, -0.3708, -0.0592,  0.5992,  0.0204,  1.0947, -1.3877],
         [-0.3383,  0.4259, -0.0961, -0.0847,  0.8261, -0.0221, -0.5580,  0.1396],
         [-0.4532,  0.2670, -0.3209, -0.0440, -0.7793, -0.0479, -0.6938,  0.7336],
         [-0.0470,  0.1081, -0.3122, -0.2186,  0.4985, -0.0314, -0.7078,  0.6986],
         [ 1.8892,  0.1876,  1.1507, -0.0439, -0.0384,  0.0102, -0.7874,  0.8184],
         [-0.1622, -0.9249, -0.2637, -0.0563, -0.6742, -0.1518, -0.9466,  0.9133],
         [ 2.0113,  0.2670,  0.7077, -0.1957, -0.0287, -0.0306,  1.0525, -1.2979],
         [ 0.9382, -0.2892,  0.6525, -0.1648,  0.2415,  0.0423, -0.6376,  0.4191],
         [ 0.6282,  0.1876, -0.0458, -0.2200, -0.5320,  0.0056, -0.8249,  0.6337],
         [ 0.3681,  0.5849,  0.0684, -0.2811, -0.6309,  0.0169, -0.8015,  0.7386]]),
 tensor([1.1380, 2.4310, 1.9900, 2.3200, 5.0000, 1.8080, 4.0180, 2.2580, 2.3490,
         1.7080])]

### Create and Train Model

In [8]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [9]:
# Initialize the MLP
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
mlp = MLP().to(device)

# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

Using cuda device


In [10]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader, 0):

        # Get and prepare inputs
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp(inputs)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

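        # Print statistics
        # Note: the sum is reset every 200 mini-batches but divided by 500,
        # so the printed value is a scaled running loss rather than a true mean.
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0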

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.005
Loss after mini-batch   201: 0.780
Loss after mini-batch   401: 0.529
Loss after mini-batch   601: 0.367
Loss after mini-batch   801: 0.313
Loss after mini-batch  1001: 0.279
Loss after mini-batch  1201: 0.268
Loss after mini-batch  1401: 0.248
Loss after mini-batch  1601: 0.250
Loss after mini-batch  1801: 0.239
Loss after mini-batch  2001: 0.231
Starting epoch 2
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.243
Loss after mini-batch   401: 0.223
Loss after mini-batch   601: 0.222
Loss after mini-batch   801: 0.213
Loss after mini-batch  1001: 0.219
Loss after mini-batch  1201: 0.212
Loss after mini-batch  1401: 0.212
Loss after mini-batch  1601: 0.204
Loss after mini-batch  1801: 0.206
Loss after mini-batch  2001: 0.194
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.205
Loss after mini-batch   401: 0.201
Loss after mini-batch   601: 0.189
Loss after mini-batch   801: 0.193
Loss
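
The training loop accumulates `current_loss` and resets it at each report. As a sketch of one way to report a true mean over each reporting window, here is the same bookkeeping in plain Python with made-up loss values; `windowed_means` and `report_every` are hypothetical names, not part of the notebook:

```python
def windowed_means(losses, report_every=200):
    """Return (last_batch_number, mean_loss) for each window of batches."""
    reports = []
    running, count = 0.0, 0
    for i, loss in enumerate(losses, start=1):
        running += loss
        count += 1
        if i % report_every == 0:
            reports.append((i, running / count))
            running, count = 0.0, 0
    if count:  # Flush the final partial window.
        reports.append((len(losses), running / count))
    return reports


# Made-up per-batch losses: 4 batches at 2.0, then 3 at 1.0.
demo = windowed_means([2.0] * 4 + [1.0] * 3, report_every=4)
```

Dividing by the number of batches actually accumulated keeps the reported values directly comparable to the per-batch L1 loss.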

### Save Model State Dict
This saves the serialized object to disk using pickle.

In [11]:
torch.save(mlp.state_dict(), "models/housing_model.pt")
print("Saved PyTorch Model State to models/housing_model.pt")

Saved PyTorch Model State to models/housing_model.pt
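
A state dict holds only the parameter tensors, so the model class has to be constructed before the weights can be loaded back. A small CPU-only sketch of the round trip, using a throwaway `nn.Linear` and a temporary directory in place of the notebook's `MLP` and `models/` path:

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
original = nn.Linear(8, 1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_state.pt")
    torch.save(original.state_dict(), path)

    # Rebuild the architecture, then load the saved parameters into it.
    restored = nn.Linear(8, 1)
    restored.load_state_dict(torch.load(path, weights_only=True))

x = torch.randn(4, 8)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
```

Passing `weights_only=True` restricts unpickling to tensors and plain containers, which is sufficient for a state dict.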


### Save Model as TorchScript
This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even Python). 

In [12]:
scripted = torch.jit.script(mlp)
scripted.save("models/ts_housing_model.pt")
print("Saved TorchScript Model to models/ts_housing_model.pt")

Saved TorchScript Model to models/ts_housing_model.pt


### Load and Test from Model State

In [13]:
loaded_mlp = MLP().to(device)
loaded_mlp.load_state_dict(torch.load("models/housing_model.pt", weights_only=True))

<All keys matched successfully>

In [14]:
testX, testY = next(iter(trainloader))

In [15]:
loaded_mlp(testX.to(device))

tensor([[2.5311],
        [3.1507],
        [2.5243],
        [1.0335],
        [2.5771],
        [2.8876],
        [0.9761],
        [4.1654],
        [1.9997],
        [1.0379]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [16]:
testY

tensor([2.4030, 2.9640, 3.6610, 1.6700, 2.0510, 2.8910, 1.0400, 3.6290, 1.0730,
        1.3080])

### Load and Test from TorchScript

In [17]:
scripted_mlp = torch.jit.load("models/ts_housing_model.pt")

In [18]:
scripted_mlp(testX.to(device)).flatten()

tensor([2.5311, 3.1507, 2.5243, 1.0335, 2.5771, 2.8876, 0.9761, 4.1654, 1.9997,
        1.0379], device='cuda:0', grad_fn=<ViewBackward0>)

### Compile using the Torch JIT Compiler
This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  

Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark. Instead, we will recompile and cache the model on the executor. 

(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)

In [19]:
import torch_tensorrt as trt
import time



In [20]:
# Optional: set the filename for the TensorRT timing cache
timestamp = time.time()
timing_cache = f"/tmp/timing_cache-{timestamp}.bin"
with open(timing_cache, "wb") as f:
    pass

In [21]:
inputs_bs1 = torch.randn((10, 8), dtype=torch.float).to("cuda")
# Mark dimension 0 (the batch dimension) of inputs_bs1 as dynamic over [1, 50],
# so changing the batch size within that range does not trigger recompilation.
torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=50)
trt_model = trt.compile(
    loaded_mlp,
    ir="torch_compile",
    inputs=inputs_bs1,
    enabled_precisions={torch.float},
    timing_cache_path=timing_cache,
)

In [22]:
stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    testX = testX.to(device)
    print(trt_model(testX))

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=53

INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 582, GPU 786 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1635, GPU +288, now: CPU 2364, GPU 1074 (MiB)

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003524
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 22240
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started

tensor([[2.5311],
        [3.1507],
        [2.5243],
        [1.0335],
        [2.5771],
        [2.8876],
        [0.9761],
        [4.1654],
        [1.9997],
        [1.0379]], device='cuda:0')


### Compile using the Torch-TensorRT AOT Compiler
Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph of the tensor computation. The export yields an `ExportedProgram` object, which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   

[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation.

In [23]:
# Preparing the inputs for batch_size = 10.
inputs = (torch.randn((10, 8), dtype=torch.float).cuda(),)

# Produce traced graph in the ExportedProgram format
exp_program = trt.dynamo.trace(loaded_mlp, inputs)
# Compile the traced graph to produce an optimized module
trt_gm = trt.dynamo.compile(exp_program,
                            inputs=inputs,
                            timing_cache_path=timing_cache)

INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache-1733767909.5417278.bin')

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChang

In [24]:
stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    testX = testX.to(device)
    print(trt_gm(testX))

tensor([[2.5311],
        [3.1507],
        [2.5243],
        [1.0335],
        [2.5771],
        [2.8876],
        [0.9761],
        [4.1654],
        [1.9997],
        [1.0379]], device='cuda:0')


We can save the compiled model using `torch_tensorrt.save`. Unfortunately, serializing the model to be reloaded at a later date currently only supports *static inputs* ([link to issue](https://github.com/pytorch/pytorch/issues/137365)).

In [25]:
with torch.cuda.stream(stream):
    trt.save(trt_gm, "models/trt_model_aot.ep", inputs=[torch.randn((10, 8), dtype=torch.float).to("cuda")])
    print("Saved AOT compiled TensorRT model to models/trt_model_aot.ep")


Saved AOT compiled TensorRT model to models/trt_model_aot.ep


## PySpark

In [26]:
from pyspark.sql.functions import col, struct, pandas_udf, array
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf
import json
import pandas as pd


Check the cluster environment to handle any platform-specific Spark configurations.

In [None]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

In [None]:
conf = SparkConf()

if on_standalone:
    conda_env = os.environ.get("CONDA_PREFIX")
    # Point PyTriton to correct libpython3.11.so:
    conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH")
    if 'spark' not in globals():
        import socket
        # If Spark was not started with Jupyter, attach to local standalone
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
elif on_dataproc:
    # Point PyTriton to correct libpython3.11.so:
    conda_lib_path="/opt/conda/miniconda3/lib"
    conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_lib_path}:$LD_LIBRARY_PATH") 

conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
conf.set("spark.executor.cores", "8")
conf.set("spark.task.resource.gpu.amount", "0.125")
conf.set("spark.executor.resource.gpu.amount", "1")
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

24/12/09 18:11:53 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
24/12/09 18:11:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/09 18:11:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
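
The GPU settings in this configuration divide one GPU among the tasks of an executor: with 8 cores and a per-task GPU amount of 0.125, up to 8 tasks can share the same GPU at once. A sketch of that arithmetic (the helper `concurrent_tasks` is hypothetical, and it assumes one CPU core per task, which is Spark's default for `spark.task.cpus`):

```python
def concurrent_tasks(executor_cores: int, executor_gpus: float, task_gpus: float) -> int:
    """Tasks per executor, limited by both CPU cores and GPU share."""
    by_cpu = executor_cores  # assumes spark.task.cpus = 1
    by_gpu = int(executor_gpus / task_gpus)
    return min(by_cpu, by_gpu)


slots = concurrent_tasks(executor_cores=8, executor_gpus=1, task_gpus=0.125)
```

Setting the task GPU amount to `1 / executor_cores` keeps the GPU from becoming the limiting resource for ordinary tasks.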


### Create Spark DataFrame from Pandas DataFrame

In [29]:
housing = fetch_california_housing()

In [30]:
X = StandardScaler().fit_transform(housing.data.astype(np.float32))

In [31]:
pdf = pd.DataFrame(X, columns=housing.feature_names)

In [32]:
schema = StructType([
    StructField("MedInc",FloatType(),True),
    StructField("HouseAge",FloatType(),True),
    StructField("AveRooms",FloatType(),True),
    StructField("AveBedrms",FloatType(),True),
    StructField("Population",FloatType(),True),
    StructField("AveOccup",FloatType(),True),
    StructField("Latitude",FloatType(),True),
    StructField("Longitude",FloatType(),True)
])

df = spark.createDataFrame(pdf, schema=schema).repartition(8)
df.show(truncate=12)


+------------+------------+-----------+------------+-----------+------------+----------+------------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|
+------------+------------+-----------+------------+-----------+------------+----------+------------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|
|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.968

In [33]:
df.schema

StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])

In [34]:
df.write.mode("overwrite").parquet("datasets/california_housing")

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays 
- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function
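
The contract between these two pieces can be sketched without Spark: a factory runs once and returns a `predict` function, which is then called on NumPy batches. Below, a fixed linear map stands in for the PyTorch model, and the slicing loop is a simplified stand-in for the batching that `predict_batch_udf` performs (both are assumptions for illustration):

```python
import numpy as np


def make_predict_fn():
    # Expensive setup (model loading, compilation) would happen here, once.
    weights = np.arange(8, dtype=np.float32)  # stand-in for a real model

    def predict(inputs: np.ndarray) -> np.ndarray:
        # inputs has shape (batch, 8); return one float per row.
        return inputs @ weights

    return predict


predict = make_predict_fn()
rows = np.ones((120, 8), dtype=np.float32)
batch_size = 50
outputs = [predict(rows[i:i + batch_size]) for i in range(0, len(rows), batch_size)]
batch_lengths = [len(o) for o in outputs]
```

Because the factory runs only once per worker, costly steps such as compilation are paid on the first batch and reused afterwards.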

In [35]:
df = spark.read.parquet("datasets/california_housing")

In [36]:
columns = df.columns

In [None]:
# get absolute path to model
model_path = "{}/models/housing_model.pt".format(os.getcwd())

For inference on Spark, we'll compile the model with the Torch-TensorRT AOT compiler and cache on the executor. We can specify dynamic batch sizes before compilation to [optimize across multiple input shapes](https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html).

In [38]:
import warnings
warnings.simplefilter("ignore", ResourceWarning)

In [None]:
def predict_batch_fn():
    import torch
    from torch import nn
    import torch_tensorrt as trt

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        raise ValueError("This function uses the TensorRT model which requires a GPU device")

    # Define model
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(8, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1)
            )

        def forward(self, x):
            return self.layers(x)

    model = MLP().to(device)
    model.load_state_dict(torch.load(model_path, weights_only=True))

    # Preparing the inputs for dynamic batch sizing.
    inputs = [trt.Input(min_shape=(20, 8), 
                        opt_shape=(50, 8), 
                        max_shape=(64, 8), 
                        dtype=torch.float32)]

    # Trace the computation graph and compile to produce an optimized module
    trt_gm = trt.compile(model, ir="dynamo", inputs=inputs, require_full_compilation=True)
    
    def predict(inputs):
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream), trt.logging.errors():
            torch_inputs = torch.from_numpy(inputs).to(device)
            outputs = trt_gm(torch_inputs) # .flatten()
            return outputs.detach().cpu().numpy()

    return predict

In [40]:
regress = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=50)
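
`predict_batch_fn` compiles the engine for batch sizes from 20 to 64, while the UDF requests `batch_size=50`, so it is worth checking which batch sizes the predict function will actually receive. The sketch below assumes that each Arrow record batch (at most `spark.sql.execution.arrow.maxRecordsPerBatch` rows) is split into chunks of `batch_size` rows with a smaller final chunk, and that a partition holds 2,580 rows (20,640 rows spread evenly over 8 partitions); real partition sizes may differ:

```python
def chunk_sizes(partition_rows: int, max_records_per_batch: int, batch_size: int):
    """Sizes of the chunks passed to the predict function for one partition."""
    sizes = []
    remaining = partition_rows
    while remaining > 0:
        arrow_rows = min(remaining, max_records_per_batch)
        remaining -= arrow_rows
        while arrow_rows > 0:
            chunk = min(arrow_rows, batch_size)
            sizes.append(chunk)
            arrow_rows -= chunk
    return sizes


sizes = chunk_sizes(partition_rows=2580, max_records_per_batch=1000, batch_size=50)
distinct = sorted(set(sizes))
in_range = all(20 <= s <= 64 for s in sizes)
```

Under these assumptions the only batch sizes are 50 and a trailing 30, both inside the compiled range; a partition whose remainder fell below 20 rows would not be.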

In [41]:
%%time
preds = df.withColumn("preds", regress(struct(*columns)))
results = preds.collect()

                                                                                

CPU times: user 147 ms, sys: 17.9 ms, total: 165 ms
Wall time: 8.37 s


In [42]:
%%time
preds = df.withColumn("preds", regress(array(*columns)))
results = preds.collect()

CPU times: user 25.7 ms, sys: 7.19 ms, total: 32.9 ms
Wall time: 241 ms


In [43]:
%%time
preds = df.withColumn("preds", regress(array(*columns)))
results = preds.collect()

CPU times: user 28.5 ms, sys: 3.82 ms, total: 32.3 ms
Wall time: 278 ms


In [44]:
preds.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|     preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053| 1.4012802|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742| 1.8582058|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378| 1.4144654|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787| 2.4578812|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.1153543|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.008

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-pytriton.png" alt="drawing" width="700"/>

In [45]:
from functools import partial

In [None]:
def triton_server(model_path):
    import signal
    import numpy as np
    import torch
    from torch import nn
    import torch_tensorrt as trt
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton
    from pyspark import TaskContext

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Define model
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(8, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1)
            )

        def forward(self, x):
            return self.layers(x)

    model = MLP().to(device)
    model.load_state_dict(torch.load(model_path, weights_only=True))

    # Preparing the inputs for dynamic batch sizing.
    inputs = [trt.Input(min_shape=(20, 8), 
                        opt_shape=(50, 8), 
                        max_shape=(64, 8), 
                        dtype=torch.float32)]

    # Trace the computation graph and compile to produce an optimized module
    trt_gm = trt.compile(model, ir="dynamo", inputs=inputs, require_full_compilation=True)

    print("SERVER: Compiled model.")

    @batch
    def _infer_fn(**inputs):
        features = np.squeeze(inputs["features"])
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream):
            torch_inputs = torch.from_numpy(features).to(device)
            outputs = trt_gm(torch_inputs)
            return {
                "preds": outputs.cpu().numpy(),
            }

    with Triton() as triton:
        triton.bind(
            model_name="HousingModel",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="features", dtype=np.float32, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="preds", dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

def start_triton(url, model_name, model_path):
    import socket
    import psutil
    from multiprocessing import Process
    from pytriton.client import ModelClient

    for conn in psutil.net_connections(kind="inet"):
        if conn.laddr.port == 8001:
            print(f"Process {conn.pid} is already running on port 8001. Please stop it before starting a new one.")
            return []

    hostname = socket.gethostname()
    process = Process(target=triton_server, args=(model_path,))
    process.start()

    client = ModelClient(url, model_name)
    ready = False
    while not ready:
        try:
            client.wait_for_server(5)
            ready = True
        except Exception as e:
            print(f"Waiting for server to be ready: {e}")
    
    return [(hostname, process.pid)]

#### Start Triton servers

To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU.  

In [47]:
def _use_stage_level_scheduling(spark, rdd):

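    # Compare numerically; a string comparison would misorder e.g. "3.10.0"
    if tuple(int(v) for v in spark.version.split(".")[:2]) < (3, 4):
        raise Exception("Stage-level scheduling is not supported in Spark < 3.4.0")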

    executor_cores = spark.conf.get("spark.executor.cores")
    assert executor_cores is not None, "spark.executor.cores is not set"
    executor_gpus = spark.conf.get("spark.executor.resource.gpu.amount")
    assert executor_gpus is not None and int(executor_gpus) <= 1, "spark.executor.resource.gpu.amount must be set and <= 1"

    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    spark_plugins = spark.conf.get("spark.plugins", " ")
    assert spark_plugins is not None
    spark_rapids_sql_enabled = spark.conf.get("spark.rapids.sql.enabled", "true")
    assert spark_rapids_sql_enabled is not None

    task_cores = (
        int(executor_cores)
        if "com.nvidia.spark.SQLPlugin" in spark_plugins
        and "true" == spark_rapids_sql_enabled.lower()
        else (int(executor_cores) // 2) + 1
    )

    task_gpus = 1.0
    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    rp = ResourceProfileBuilder().require(treqs).build
    print(f"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})")

    return rdd.withResources(rp)
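
The `task_cores` expression is what limits each executor to a single task in this stage: a task that requests more than half of the executor's cores leaves no room for a second one. A sketch of that arithmetic (the helper names are hypothetical):

```python
def stage_task_cores(executor_cores: int, rapids_enabled: bool) -> int:
    """Cores requested per task, mirroring the expression in the function above."""
    return executor_cores if rapids_enabled else executor_cores // 2 + 1


def tasks_per_executor(executor_cores: int, task_cores: int) -> int:
    """How many such tasks fit on one executor by CPU count alone."""
    return executor_cores // task_cores


cores = stage_task_cores(executor_cores=8, rapids_enabled=False)
fit = tasks_per_executor(8, cores)
```

With 8 executor cores this requests 5 cores per task, so only one task fits per executor; the `task_gpus = 1.0` request imposes the same limit on the GPU side.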

**Specify the number of nodes in the cluster.**  
Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 2 nodes by default. 

In [48]:
num_nodes = 1  # Change based on cluster setup

In [None]:
url = "localhost"
model_name = "HousingModel"

# get absolute path to model
model_path = "{}/models/housing_model.pt".format(os.getcwd())

sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)

Requesting stage-level resources: (cores=5, gpu=1.0)


In [50]:
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(url, model_name, model_path)).collectAsMap()
print("Triton Server PIDs:\n", json.dumps(pids, indent=4))

[Stage 11:>                                                         (0 + 1) / 1]

Triton Server PIDs:
 {
    "cb4ae00-lcedt": 150085
}


                                                                                

#### Define client function

In [51]:
def triton_fn(url, model_name, init_timeout_s):
    from pytriton.client import ModelClient

    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            result_data = client.infer_batch(inputs)
            return result_data["preds"]
        
    return infer_batch

#### Run Inference

In [52]:
df = spark.read.parquet("datasets/california_housing")

In [53]:
columns = df.columns

In [54]:
regress = predict_batch_udf(partial(triton_fn, url="localhost", model_name="HousingModel", init_timeout_s=500),
                               input_tensor_shapes=[[8]],
                               return_type=FloatType(),
                               batch_size=50)
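
`predict_batch_udf` calls its first argument with no arguments, so `functools.partial` is used to bind `url`, `model_name`, and `init_timeout_s` ahead of time. A Spark-free sketch of that pattern, with a hypothetical `make_fn` factory standing in for `triton_fn`:

```python
from functools import partial


def make_fn(url, model_name, init_timeout_s):
    """Stand-in factory: returns a function that remembers its configuration."""
    def infer_batch(inputs):
        return f"{model_name}@{url} got {len(inputs)} rows (timeout {init_timeout_s}s)"
    return infer_batch


# Bind every argument now; the result can later be called with none.
factory = partial(make_fn, url="localhost", model_name="HousingModel", init_timeout_s=500)
infer = factory()
message = infer([[0.0] * 8] * 3)
```

The bound factory is picklable as long as the wrapped function and its arguments are, which lets Spark ship it to the executors.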

In [55]:
%%time
# first pass caches model/fn
predictions = df.withColumn("preds", regress(struct(*columns)))
preds = predictions.collect()

[Stage 13:>                                                         (0 + 8) / 8]

CPU times: user 179 ms, sys: 5.13 ms, total: 185 ms
Wall time: 4.22 s


                                                                                

In [56]:
%%time
predictions = df.withColumn("preds", regress(array(*columns)))
preds = predictions.collect()

CPU times: user 29.3 ms, sys: 6.09 ms, total: 35.4 ms
Wall time: 630 ms


In [57]:
%%time
predictions = df.withColumn("preds", regress(array(*columns)))
preds = predictions.collect()



CPU times: user 27.2 ms, sys: 2.04 ms, total: 29.2 ms
Wall time: 944 ms


                                                                                

In [58]:
predictions.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|     preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053| 1.4012802|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742| 1.8582058|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378| 1.4144654|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787| 2.4578812|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.1153543|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.008

#### Stop Triton Server on each executor

In [59]:
def stop_triton(pids):
    import os
    import socket
    import signal
    import time 
    
    hostname = socket.gethostname()
    pid = pids.get(hostname, None)
    assert pid is not None, f"Could not find pid for {hostname}"
    os.kill(pid, signal.SIGTERM)
    time.sleep(7)
    
    for _ in range(5):
        try:
            os.kill(pid, 0)
        except OSError:
            return [True]
        time.sleep(5)

    return [False]

shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)
shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()

Requesting stage-level resources: (cores=5, gpu=1.0)


                                                                                

[True]
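
`stop_triton` relies on `os.kill(pid, 0)`, which delivers no signal but raises `OSError` when the process no longer exists. A POSIX-only sketch of that probe on a throwaway child process (the helper `is_alive` is hypothetical):

```python
import os
import signal
import subprocess
import sys


def is_alive(pid: int) -> bool:
    """Signal 0 performs error checking only; it does not deliver a signal."""
    try:
        os.kill(pid, 0)
    except OSError:
        return False
    return True


child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
alive_before = is_alive(child.pid)
os.kill(child.pid, signal.SIGTERM)
child.wait()  # Reap the child; an unreaped (zombie) child still answers signal 0.
alive_after = is_alive(child.pid)
```

The sleep-and-retry loop in `stop_triton` exists because termination is not instantaneous: the probe may still succeed for a short time after `SIGTERM` is sent.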

In [60]:
spark.stop()