<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark PyTorch Inference

### Regression

In this notebook, we will train an MLP to perform regression on the California housing dataset, and load it for distributed inference with Spark.  

Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md  

We also demonstrate accelerated inference via Torch-TensorRT model compilation.   

In [1]:
import torch
import os
import shutil
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [2]:
os.makedirs('models', exist_ok=True)

In [3]:
torch.__version__

'2.5.1+cu124'

### Load Dataset

Each label is the median house value of a California district, in units of $100,000, which we'll try to predict using the following features:  
['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude']

In [4]:
X, y = fetch_california_housing(return_X_y=True)

In [5]:
class HousingDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        # Tensors are used as-is; NumPy arrays are optionally scaled, then converted
        if not torch.is_tensor(X):
            if scale_data:
                X = StandardScaler().fit_transform(X)
            X = torch.from_numpy(X.astype(np.float32))
        if not torch.is_tensor(y):
            y = torch.from_numpy(y.astype(np.float32))
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]
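
`StandardScaler` rescales each feature column to zero mean and unit variance. Here is a minimal pure-Python sketch of that transformation for a single column, assuming the population standard deviation (scikit-learn's default behaviour). The helper name `standardize` and the sample values are illustrative and not part of the notebook.

```python
from statistics import fmean, pstdev

def standardize(column):
    """Return the z-scores of a list of floats: (x - mean) / population std."""
    mu = fmean(column)
    sigma = pstdev(column)
    if sigma == 0.0:
        # A constant column has no spread; map every value to zero.
        return [0.0 for _ in column]
    return [(x - mu) / sigma for x in column]

# Hypothetical house-age values, not taken from the dataset.
house_age = [41.0, 21.0, 52.0, 52.0, 34.0]
scaled = standardize(house_age)
print([round(z, 3) for z in scaled])
```

After scaling, the column has mean 0 and standard deviation 1, which keeps features with large raw ranges (such as `Population`) from dominating the gradient updates.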

In [6]:
dataset = HousingDataset(X, y)
trainloader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=1)

In [None]:
next(iter(trainloader))

[tensor([[ 6.5799e-01,  4.2594e-01, -1.4755e-01, -2.3638e-01, -4.0221e-01,
          -5.6793e-02,  8.8868e-01, -1.3528e+00],
         [ 6.7288e-01, -1.0043e+00,  5.7486e-01, -1.6537e-01, -3.3422e-01,
          -6.4971e-02, -1.2790e+00,  1.2327e+00],
         [-1.1616e-01,  2.8646e-02, -1.7830e-01, -2.3817e-01, -6.7154e-01,
          -3.6429e-02, -1.3258e+00,  1.2726e+00],
         [-3.2513e-01, -6.8648e-01, -3.4226e-01, -8.2805e-02,  5.1239e+00,
           2.6689e-02, -7.7338e-01,  8.3340e-01],
         [ 1.0892e-01, -1.2427e+00,  2.7819e-01, -8.7150e-02,  3.0158e-01,
          -1.8564e-02, -1.1245e+00,  1.1628e+00],
         [-8.6416e-02,  5.8485e-01, -7.8085e-02,  8.1655e-02, -6.7154e-01,
          -1.6053e-02, -3.4733e-01,  1.2577e+00],
         [-1.2463e-01,  1.0810e-01,  2.6662e-01, -1.0883e-01,  3.4839e-01,
          -2.3125e-02, -7.7338e-01,  1.3325e+00],
         [-9.2662e-01, -1.6400e+00, -2.4824e-01,  6.0041e-01,  6.3361e-01,
          -1.0926e-01, -8.8574e-01,  1.2826e+00],


### Create and Train Model

In [8]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)
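
To get a feel for how small this network is, we can count its trainable parameters by hand. The sketch below assumes every layer is a fully connected layer with a bias term, as in the `MLP` above; the helper `count_mlp_parameters` is a hypothetical name introduced for illustration.

```python
def count_mlp_parameters(layer_sizes):
    """Count weights and biases in a stack of fully connected layers."""
    total = 0
    for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
        total += fan_in * fan_out  # weight matrix entries
        total += fan_out           # bias vector entries
    return total

# Layer widths of the MLP above: 8 inputs -> 64 -> 32 -> 1 output.
n_params = count_mlp_parameters([8, 64, 32, 1])
print(n_params)  # 2689
```

The total breaks down as 576 + 2080 + 33 parameters across the three linear layers.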

In [None]:
# Initialize the MLP
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
mlp = MLP().to(device)

# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

Using cuda device


In [10]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader, 0):

        # Get and prepare inputs
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp(inputs)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

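        # Print statistics
        # NOTE: the running loss is reset every 200 batches but divided by 500,
        # so the printed values are scaled down rather than true per-batch means.
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0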

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.004
Loss after mini-batch   201: 0.701
Loss after mini-batch   401: 0.463
Loss after mini-batch   601: 0.329
Loss after mini-batch   801: 0.285
Loss after mini-batch  1001: 0.253
Loss after mini-batch  1201: 0.247
Loss after mini-batch  1401: 0.234
Loss after mini-batch  1601: 0.232
Loss after mini-batch  1801: 0.217
Loss after mini-batch  2001: 0.211
Starting epoch 2
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.205
Loss after mini-batch   401: 0.212
Loss after mini-batch   601: 0.206
Loss after mini-batch   801: 0.205
Loss after mini-batch  1001: 0.202
Loss after mini-batch  1201: 0.202
Loss after mini-batch  1401: 0.204
Loss after mini-batch  1601: 0.198
Loss after mini-batch  1801: 0.188
Loss after mini-batch  2001: 0.188
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.197
Loss after mini-batch   401: 0.193
Loss after mini-batch   601: 0.196
Loss after mini-batch   801: 0.189
Loss
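
The training loop reports progress by accumulating the loss between print statements. Below is a minimal sketch of periodic loss reporting in which the divisor equals the number of batches accumulated, so each reported value is a per-batch mean. The generator `windowed_mean_losses` and the synthetic loss values are assumptions made for illustration; they are not taken from the run above.

```python
def windowed_mean_losses(losses, window=200):
    """Yield (batch_number, mean_loss) once per full window of batches."""
    running = 0.0
    for i, loss in enumerate(losses, start=1):
        running += loss
        if i % window == 0:
            yield i, running / window
            running = 0.0

# Synthetic, steadily decreasing losses standing in for loss.item() values.
fake_losses = [1.0 / (1 + 0.01 * step) for step in range(600)]
report = list(windowed_mean_losses(fake_losses, window=200))
for batch, mean_loss in report:
    print(f"Mean loss over batches {batch - 199}-{batch}: {mean_loss:.3f}")
```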

### Save Model State Dict
This saves the serialized object to disk using pickle.

In [11]:
torch.save(mlp.state_dict(), "models/housing_model.pt")
print("Saved PyTorch Model State to models/housing_model.pt")

Saved PyTorch Model State to models/housing_model.pt


### Save Model as TorchScript
This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even Python) to load.

In [None]:
scripted = torch.jit.script(mlp)
scripted.save("models/ts_housing_model.pt")
print("Saved TorchScript Model to models/ts_housing_model.pt")

Saved TorchScript Model to models/ts_housing_model.pt


### Load and Test from Model State

In [13]:
loaded_mlp = MLP().to(device)
loaded_mlp.load_state_dict(torch.load("models/housing_model.pt", weights_only=True))

<All keys matched successfully>

In [14]:
testX, testY = next(iter(trainloader))

In [None]:
print("Predictions:")
loaded_mlp(testX.to(device))

Predictions:


tensor([[2.3652],
        [1.8444],
        [2.4587],
        [3.1243],
        [2.2726],
        [2.1818],
        [1.5222],
        [0.5554],
        [2.2508],
        [3.5971]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [16]:
print("Labels:")
testY

Labels:


tensor([2.7370, 2.2110, 2.5360, 2.6330, 1.6540, 2.3360, 1.4600, 0.6590, 2.6380,
        3.6220])
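
To put a number on how close the predictions are to the labels, we can compute the mean absolute error for this batch. This is the same quantity `nn.L1Loss` reports with its default mean reduction. The sketch below copies the values from the two outputs above into plain Python lists; the helper `mean_absolute_error` is introduced here for illustration.

```python
# Values copied from the "Predictions" and "Labels" outputs above.
predictions = [2.3652, 1.8444, 2.4587, 3.1243, 2.2726,
               2.1818, 1.5222, 0.5554, 2.2508, 3.5971]
labels = [2.7370, 2.2110, 2.5360, 2.6330, 1.6540,
          2.3360, 1.4600, 0.6590, 2.6380, 3.6220]

def mean_absolute_error(y_pred, y_true):
    """Average of |prediction - label| over a batch."""
    if len(y_pred) != len(y_true):
        raise ValueError("inputs must have the same length")
    return sum(abs(p - t) for p, t in zip(y_pred, y_true)) / len(y_true)

mae = mean_absolute_error(predictions, labels)
print(f"MAE on this batch: {mae:.4f}")
```

Because `testX` and `testY` were drawn from the training loader, this is a sanity check of the saved weights rather than a held-out error estimate.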

### Load and Test from TorchScript

In [17]:
scripted_mlp = torch.jit.load("models/ts_housing_model.pt")

In [None]:
print("Predictions:")
scripted_mlp(testX.to(device)).flatten()

Predictions:


tensor([2.3652, 1.8444, 2.4587, 3.1243, 2.2726, 2.1818, 1.5222, 0.5554, 2.2508,
        3.5971], device='cuda:0', grad_fn=<ViewBackward0>)

### Compile using the Torch JIT Compiler
This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  

Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark.  

(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)

In [19]:
import torch_tensorrt as trt
import time

In [20]:
# Optional: set the filename for the TensorRT timing cache
timestamp = time.time()
timing_cache = f"/tmp/timing_cache-{timestamp}.bin"
with open(timing_cache, "wb") as f:
    pass

In [21]:
inputs_bs10 = torch.randn((10, 8), dtype=torch.float).to("cuda")
# Mark dimension 0 (the batch dimension) as dynamic with a range of [1, 50].
# No recompilation will happen when the batch size changes within this range.
torch._dynamo.mark_dynamic(inputs_bs10, 0, min=1, max=50)
trt_model = trt.compile(
    loaded_mlp,
    ir="torch_compile",
    inputs=inputs_bs10,
    enabled_precisions={torch.float},
    timing_cache_path=timing_cache,
)

In [22]:
stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    testX = testX.to(device)
    print("Predictions:")
    print(trt_model(testX))



Predictions:
tensor([[2.3652],
        [1.8444],
        [2.4587],
        [3.1243],
        [2.2726],
        [2.1818],
        [1.5222],
        [0.5554],
        [2.2508],
        [3.5971]], device='cuda:0')


### Compile using the Torch-TensorRT AOT Compiler
Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph of the tensor computation. The result is an `ExportedProgram` object, which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   

[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation.

In [23]:
example_inputs = (torch.randn((10, 8), dtype=torch.float).to("cuda"),)

# Mark dim 0 (batch size) as dynamic
batch = torch.export.Dim("batch", min=1, max=64)
# Produce traced graph in ExportedProgram format
exp_program = torch.export.export(loaded_mlp, args=example_inputs, dynamic_shapes={"x": {0: batch}})
# Compile the traced graph to produce an optimized module
trt_gm = trt.dynamo.compile(exp_program,
                            tuple(example_inputs),
                            enabled_precisions={torch.float},
                            timing_cache_path=timing_cache,
                            workspace_size=1<<30)

In [24]:
print(type(exp_program))
print(type(trt_gm))

<class 'torch.export.exported_program.ExportedProgram'>
<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>


In [25]:
stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    print("Predictions:")
    testX = testX.to(device)
    print(trt_gm(testX))

Predictions:
tensor([[2.3653],
        [1.8443],
        [2.4586],
        [3.1242],
        [2.2725],
        [2.1815],
        [1.5221],
        [0.5556],
        [2.2508],
        [3.5971]], device='cuda:0')
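
The AOT-compiled outputs differ from the eager PyTorch predictions in the fourth decimal place. A quick way to confirm that the two agree within a tolerance is sketched below, using the values printed above. The tolerance of `1e-3` is an arbitrary choice for illustration, and the helper names are hypothetical.

```python
import math

# Values copied from the eager PyTorch and Torch-TensorRT AOT outputs above.
eager = [2.3652, 1.8444, 2.4587, 3.1243, 2.2726,
         2.1818, 1.5222, 0.5554, 2.2508, 3.5971]
trt_aot = [2.3653, 1.8443, 2.4586, 3.1242, 2.2725,
           2.1815, 1.5221, 0.5556, 2.2508, 3.5971]

def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length lists."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))

def all_close(a, b, abs_tol=1e-3):
    """True if every pair of elements agrees within abs_tol."""
    return all(math.isclose(x, y, abs_tol=abs_tol) for x, y in zip(a, b))

print(f"max |eager - trt| = {max_abs_diff(eager, trt_aot):.4f}")
print("within tolerance:", all_close(eager, trt_aot))
```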


We can run the optimized module with a few different batch sizes (without recompilation!):

In [None]:
inputs = (torch.randn((10, 8), dtype=torch.float).cuda(),)
inputs_bs1 = (torch.randn((1, 8), dtype=torch.float).cuda(),)
inputs_bs50 = (torch.randn((50, 8), dtype=torch.float).cuda(),)

stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    print("Output shapes:")
    print(trt_gm(*inputs).shape)
    print(trt_gm(*inputs_bs1).shape)
    print(trt_gm(*inputs_bs50).shape)

Output shapes:
torch.Size([10, 1])
torch.Size([1, 1])
torch.Size([50, 1])


We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date.

In [None]:
torch.export.save(exp_program, "models/trt_housing_model.ep")
print("Saved ExportedProgram to models/trt_housing_model.ep")

Saved ExportedProgram to models/trt_housing_model.ep


## PySpark

In [None]:
from pyspark.sql.functions import col, struct, pandas_udf, array
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf
import json
import pandas as pd

Check the cluster environment to handle any platform-specific Spark configurations.

In [29]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For cloud service provider (CSP) environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [30]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")
    
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/02/04 13:46:28 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/02/04 13:46:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/04 13:46:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
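
The configuration above gives each executor 8 cores and 1 GPU, and sets `spark.task.resource.gpu.amount` to 0.125 so that tasks share the GPU. The sketch below is a simplified model of how those settings bound task concurrency on one executor. It assumes `spark.task.cpus` is left at its default of 1, and the function name is hypothetical. It is a back-of-the-envelope calculation, not Spark's scheduler logic.

```python
import math

def concurrent_tasks_per_executor(executor_cores, task_cpus,
                                  executor_gpus, task_gpu_amount):
    """Tasks one executor can run at once, limited by CPU and GPU slots."""
    cpu_slots = executor_cores // task_cpus
    # Small epsilon guards against floating-point error in the division.
    gpu_slots = math.floor(executor_gpus / task_gpu_amount + 1e-9)
    return min(cpu_slots, gpu_slots)

# Settings from the SparkConf above.
slots = concurrent_tasks_per_executor(
    executor_cores=8, task_cpus=1, executor_gpus=1, task_gpu_amount=0.125)
print(slots)  # 8
```

With these values, all 8 cores can run inference tasks against the same GPU at once. Raising the per-task GPU amount to 0.5 would cap concurrency at 2 tasks regardless of core count.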


### Create Spark DataFrame from Pandas DataFrame

In [31]:
housing = fetch_california_housing()

In [32]:
X = StandardScaler().fit_transform(housing.data.astype(np.float32))

In [33]:
pdf = pd.DataFrame(X, columns=housing.feature_names)

In [34]:
schema = StructType([
    StructField("MedInc",FloatType(),True),
    StructField("HouseAge",FloatType(),True),
    StructField("AveRooms",FloatType(),True),
    StructField("AveBedrms",FloatType(),True),
    StructField("Population",FloatType(),True),
    StructField("AveOccup",FloatType(),True),
    StructField("Latitude",FloatType(),True),
    StructField("Longitude",FloatType(),True)
])

df = spark.createDataFrame(pdf, schema=schema).repartition(8)
df.show(truncate=12)

+------------+------------+-----------+------------+-----------+------------+----------+------------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|
+------------+------------+-----------+------------+-----------+------------+----------+------------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|
|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.968

In [35]:
df.schema

StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])

In [36]:
data_path = "spark-dl-datasets/california_housing"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays 
- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function
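
The division of labour between the two pieces can be sketched in plain Python. This is a conceptual model only, not Spark's implementation. A factory function does the expensive setup once and returns a closure, and a driver loop slices a partition into batches and feeds them to that closure. All names (`make_predict_fn`, `run_partition`, `load_log`) and the toy linear "model" are hypothetical.

```python
import math

def make_predict_fn(load_log):
    """Stand-in for predict_batch_fn: expensive setup once, then return a closure."""
    load_log.append("model loaded")   # simulates loading/compiling the model
    weights = [0.5] * 8               # hypothetical linear "model" over 8 features

    def predict(batch):
        # batch: list of rows, each row a list of 8 floats
        return [sum(w * x for w, x in zip(weights, row)) for row in batch]

    return predict

def run_partition(rows, make_fn, batch_size):
    """Mimic how a batched UDF pushes one partition through the predict function."""
    predict = make_fn()               # called once, then reused for every batch
    outputs = []
    for start in range(0, len(rows), batch_size):
        outputs.extend(predict(rows[start:start + batch_size]))
    return outputs

load_log = []
rows = [[float(i)] * 8 for i in range(120)]
preds = run_partition(rows, lambda: make_predict_fn(load_log), batch_size=50)
print(len(preds), len(load_log), math.ceil(len(rows) / 50))
```

The key design point is that model loading happens in the outer function, so its cost is paid once per worker rather than once per batch.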

In [37]:
df = spark.read.parquet(data_path)

In [38]:
columns = df.columns

In [39]:
# get absolute path to model
model_path = "{}/models/trt_housing_model.ep".format(os.getcwd())

# For cloud environments, copy the model to the distributed file system.
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    dbfs_model_path = "/dbfs/FileStore/spark-dl-models/trt_housing_model.ep"
    shutil.copy(model_path, dbfs_model_path)
    model_path = dbfs_model_path
elif on_dataproc:
    # GCS is mounted at /mnt/gcs by the init script
    models_dir = "/mnt/gcs/spark-dl/models"
    os.makedirs(models_dir, exist_ok=True)
    gcs_model_path = models_dir + "/trt_housing_model.ep"
    shutil.copy(model_path, gcs_model_path)
    model_path = gcs_model_path

For inference on Spark, each executor loads the ExportedProgram, compiles it with the Torch-TensorRT AOT compiler, and caches the compiled model for reuse.

In [40]:
# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.
# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.

import warnings
warnings.simplefilter("ignore", ResourceWarning)

In [41]:
def predict_batch_fn():
    import torch
    import torch_tensorrt as trt

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        raise ValueError("This function uses the TensorRT model which requires a GPU device")

    example_inputs = (torch.randn((50, 8), dtype=torch.float).to("cuda"),)
    exp_program = torch.export.load(model_path)
    trt_gm = trt.dynamo.compile(exp_program,
                            tuple(example_inputs),
                            enabled_precisions={torch.float},
                            timing_cache_path=timing_cache,
                            workspace_size=1<<30)

    print("Model compiled.")
    
    def predict(inputs):
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream), trt.logging.errors():
            print(f"Predict {inputs.shape}")
            torch_inputs = torch.from_numpy(inputs).to(device)
            outputs = trt_gm(torch_inputs) # .flatten()
            return outputs.detach().cpu().numpy()

    return predict

In [42]:
regress = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=50)

In [43]:
%%time
preds = df.withColumn("preds", regress(struct(*columns)))
results = preds.collect()



CPU times: user 30.4 ms, sys: 13.1 ms, total: 43.5 ms
Wall time: 10.1 s


                                                                                

In [44]:
%%time
preds = df.withColumn("preds", regress(array(*columns)))
results = preds.collect()

CPU times: user 31.6 ms, sys: 7.39 ms, total: 39 ms
Wall time: 263 ms


In [45]:
%%time
preds = df.withColumn("preds", regress(array(*columns)))
results = preds.collect()

CPU times: user 28.7 ms, sys: 6.67 ms, total: 35.4 ms
Wall time: 296 ms


In [46]:
preds.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|    preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.3746364|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|1.8087528|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|1.4245079|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|2.3895802|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|1.3616933|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|0.7

In [47]:
# This will clear the engine cache (containing previously compiled TensorRT engines) and reset the CUDA Context.
torch._dynamo.reset()

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-server.png" alt="drawing" width="700"/>

In [48]:
from functools import partial

Import the helper class from server_utils.py:

In [49]:
sc.addPyFile("server_utils.py")

from server_utils import TritonServerManager

Define the Triton Server function:

In [50]:
def triton_server(ports, model_path):
    import time
    import signal
    import numpy as np
    import torch
    import torch_tensorrt as trt
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    exp_program = torch.export.load(model_path)
    example_inputs = (torch.randn((50, 8), dtype=torch.float).to("cuda"),)
    trt_gm = trt.dynamo.compile(exp_program,
                            tuple(example_inputs),
                            enabled_precisions={torch.float},
                            workspace_size=1<<30)

    print("SERVER: Compiled model.")

    @batch
    def _infer_fn(**inputs):
        features = inputs["features"]
        if len(inputs["features"]) != 1:
            features = np.squeeze(features)
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream):
            torch_inputs = torch.from_numpy(features).to(device)
            outputs = trt_gm(torch_inputs)
            return {
                "preds": outputs.cpu().numpy(),
            }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="HousingModel",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="features", dtype=np.float32, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="preds", dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=50,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP/gRPC/metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
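
The first step, finding available ports, can be illustrated with the standard library. The sketch below asks the operating system for free TCP ports by binding to port 0. It is not the implementation used by `TritonServerManager` (which lives in `server_utils.py` and is not shown here); the function `find_free_ports` is a hypothetical helper.

```python
import socket

def find_free_ports(count=3, host="127.0.0.1"):
    """Ask the OS for `count` currently unused TCP ports on `host`."""
    sockets = []
    try:
        for _ in range(count):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind((host, 0))  # port 0 lets the OS pick a free port
            sockets.append(s)
        # All sockets are still open here, so the ports are distinct.
        return [s.getsockname()[1] for s in sockets]
    finally:
        for s in sockets:
            s.close()

http_port, grpc_port, metrics_port = find_free_ports(3)
print(http_port, grpc_port, metrics_port)
```

Note that a port found this way can be claimed by another process between the check and the server start, so a production implementation would typically retry on bind failure.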

In [54]:
model_name = "HousingModel"
server_manager = TritonServerManager(model_name=model_name, model_path=model_path)

In [None]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server)

2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-07 11:03:44,810 - INFO - Starting 1 servers.


                                                                                

{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}

#### Define client function

Get the hostname -> url mapping from the server manager:

In [None]:
host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url

Define the Triton inference function, which returns a predict function for batch inference through the server:

In [57]:
def triton_fn(model_name, host_to_url):
    import socket
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=240) as client:
            result_data = client.infer_batch(inputs)
            return result_data["preds"]
        
    return infer_batch

In [60]:
regress = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),
                               input_tensor_shapes=[[8]],
                               return_type=FloatType(),
                               batch_size=50)

### Run Inference

In [58]:
df = spark.read.parquet(data_path)

In [59]:
columns = df.columns

In [61]:
%%time
# first pass caches model/fn
predictions = df.withColumn("preds", regress(struct(*columns)))
preds = predictions.collect()

CPU times: user 25.8 ms, sys: 6.21 ms, total: 32.1 ms
Wall time: 2.37 s


                                                                                

In [62]:
%%time
predictions = df.withColumn("preds", regress(array(*columns)))
preds = predictions.collect()

CPU times: user 171 ms, sys: 3.76 ms, total: 174 ms
Wall time: 2.5 s


                                                                                

In [63]:
%%time
predictions = df.withColumn("preds", regress(array(*columns)))
preds = predictions.collect()

CPU times: user 24.4 ms, sys: 4.83 ms, total: 29.2 ms
Wall time: 1.97 s


                                                                                

In [64]:
predictions.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|    preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+---------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|1.3746364|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|1.8087528|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|1.4245079|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|2.3895802|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|1.3616933|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|0.7

#### Stop Triton Server on each executor

In [65]:
server_manager.stop_servers()

                                                                                

[True]

In [66]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()