# PySpark PyTorch Inference

### Regression
Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md  

For the first MLP (array inputs) we'll also demonstrate accelerated inference on GPU with Torch-TensorRT. 

In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [2]:
torch.__version__

'2.4.1+cu121'

In [3]:
torch.manual_seed(42)

<torch._C.Generator at 0x7813ac183330>

In [4]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [5]:
X, y = fetch_california_housing(return_X_y=True)

In [6]:
class HousingDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            # Apply scaling if necessary
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X.astype(np.float32))
            self.y = torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

In [7]:
dataset = HousingDataset(X, y)
trainloader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=1)

In [8]:
next(iter(trainloader))

[tensor([[ 0.4159,  0.3465,  0.2294, -0.1113, -0.3130,  0.0527, -0.5955,  0.3193],
         [-0.3706,  0.5849, -0.1514, -0.2040,  0.3943,  0.0550, -0.8155,  0.6687],
         [-0.2024, -0.8454,  0.1928,  0.0087,  0.5435, -0.0279,  0.9449, -1.2679],
         [ 0.1064, -1.9578,  0.2968,  0.0363,  2.7556, -0.0217, -0.4925,  0.7685],
         [ 0.1057,  1.0616,  0.1675, -0.0081, -0.3651, -0.0372, -0.6751,  0.7186],
         [ 0.3343, -1.4811, -0.7187,  0.2041,  0.0967,  0.0529, -0.8483,  0.8234],
         [-0.2691,  1.3000, -0.6491, -0.0872,  1.7074, -0.1372, -0.7312,  0.6088],
         [-0.0891, -0.3686,  0.0260, -0.1563,  0.3996,  0.0449,  1.1790, -1.3378],
         [ 0.1397, -0.2097,  0.1816, -0.1881,  0.4049, -0.0229, -0.7547,  1.2676],
         [-0.3399,  0.3465, -0.4621, -0.0519,  0.6115, -0.0284, -0.6704,  0.5189]]),
 tensor([1.8790, 1.1770, 3.3160, 1.5430, 2.3400, 1.5240, 2.8750, 1.2360, 1.2100,
         2.0530])]

In [9]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [10]:
# Initialize the MLP
mlp = MLP().to(device)

# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

In [11]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader, 0):

        # Get and prepare inputs
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp(inputs)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

        # Print statistics
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.004
Loss after mini-batch   201: 0.733
Loss after mini-batch   401: 0.534
Loss after mini-batch   601: 0.403
Loss after mini-batch   801: 0.330
Loss after mini-batch  1001: 0.269
Loss after mini-batch  1201: 0.232
Loss after mini-batch  1401: 0.226
Loss after mini-batch  1601: 0.223
Loss after mini-batch  1801: 0.214
Loss after mini-batch  2001: 0.214
Starting epoch 2
Loss after mini-batch     1: 0.002
Loss after mini-batch   201: 0.211
Loss after mini-batch   401: 0.205
Loss after mini-batch   601: 0.199
Loss after mini-batch   801: 0.192
Loss after mini-batch  1001: 0.194
Loss after mini-batch  1201: 0.196
Loss after mini-batch  1401: 0.193
Loss after mini-batch  1601: 0.194
Loss after mini-batch  1801: 0.187
Loss after mini-batch  2001: 0.197
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.190
Loss after mini-batch   401: 0.183
Loss after mini-batch   601: 0.181
Loss after mini-batch   801: 0.190
Loss

### Save Model State Dict
This saves the serialized object to disk using pickle.

In [12]:
torch.save(mlp.state_dict(), "housing_model.pt")
print("Saved PyTorch Model State to housing_model.pt")

Saved PyTorch Model State to housing_model.pt


### Save Model as TorchScript
This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). 

In [13]:
scripted = torch.jit.script(mlp)
scripted.save("ts_housing_model.pt")
print("Saved TorchScript Model to ts_housing_model.pt")

Saved TorchScript Model to ts_housing_model.pt


### Load and Test from Model State

In [14]:
loaded_mlp = MLP().to(device)
loaded_mlp.load_state_dict(torch.load("housing_model.pt", weights_only=True))

<All keys matched successfully>

In [15]:
testX, testY = next(iter(trainloader))

In [16]:
loaded_mlp(testX.to(device))

tensor([[1.7498],
        [3.0116],
        [1.1925],
        [4.0598],
        [2.0545],
        [2.9072],
        [2.0551],
        [4.6094],
        [1.0068],
        [1.1174]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [17]:
testY

tensor([3.5000, 2.9130, 0.7020, 5.0000, 1.7970, 2.7080, 2.1470, 5.0000, 0.6000,
        0.8480])

### Load and Test from TorchScript

In [18]:
scripted_mlp = torch.jit.load("ts_housing_model.pt")

In [19]:
scripted_mlp(testX.to(device)).flatten()

tensor([1.7498, 3.0116, 1.1925, 4.0598, 2.0545, 2.9072, 2.0551, 4.6094, 1.0068,
        1.1174], device='cuda:0', grad_fn=<ViewBackward0>)

### Compile using the Torch JIT Compiler
This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  

Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark. Instead, we will recompile and cache the model on the executor. 

(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)

In [None]:
import torch_tensorrt as trt
import time

In [None]:
# Optional: set the filename for the TensorRT timing cache
timestamp = time.time()
timing_cache = f"/tmp/timing_cache-{timestamp}.bin"
with open(timing_cache, "wb") as f:
    pass

In [20]:
inputs_bs1 = torch.randn((10, 8), dtype=torch.float).to("cuda")
# This indicates dimension 0 of inputs_bs1 is dynamic whose range of values is [1, 50]. No recompilation will happen when the batch size changes.
torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=50)
trt_model = trt.compile(
    loaded_mlp,
    ir="torch_compile",
    inputs=inputs_bs1,
    enabled_precisions={torch.float},
    timing_cache_path=timing_cache,
)

stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    testX = testX.to(device)
    print(trt_model(testX))

INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Device not specified, using Torch default current device - cuda:0. If this is incorrect, please specify an input device, via the device keyword.
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=53

tensor([[1.7498],
        [3.0116],
        [1.1925],
        [4.0598],
        [2.0545],
        [2.9072],
        [2.0551],
        [4.6094],
        [1.0068],
        [1.1174]], device='cuda:0')


### Compile using the Torch-TensorRT AOT Compiler
Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph representing the Tensor computation in an AOT fashion, which produces a `ExportedProgram` object which can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   

[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation.

In [21]:
# Preparing the inputs for batch_size = 50. 
inputs = (torch.randn((10, 8), dtype=torch.float).cuda(),)

# Produce traced graph in the ExportedProgram format
exp_program = trt.dynamo.trace(loaded_mlp, inputs)
# Compile the traced graph to produce an optimized module
trt_gm = trt.dynamo.compile(exp_program,
                            inputs=inputs,
                            timing_cache_path=timing_cache)

stream = torch.cuda.Stream()
with torch.no_grad(), torch.cuda.stream(stream):
    testX = testX.to(device)
    print(trt_model(testX))

INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +

tensor([[1.7498],
        [3.0116],
        [1.1925],
        [4.0598],
        [2.0545],
        [2.9072],
        [2.0551],
        [4.6094],
        [1.0068],
        [1.1174]], device='cuda:0')


We can save the compiled model using `torch_tensorrt.save`. Unfortunately, serializing the model to be reloaded at a later date currently only supports *static inputs*.

In [22]:
with torch.cuda.stream(stream):
    trt.save(trt_gm, "trt_model_aot.ep", inputs=[torch.randn((10, 8), dtype=torch.float).to("cuda")])
    print("Saved AOT compiled TensorRT model to trt_model_aot.ep")

  engine_node = gm.graph.get_attr(engine_name)




Saved AOT compiled TensorRT model to trt_model_aot.ep


### Columns as separate input variables

In [23]:
import numpy as np
import pandas as pd
import torch

from inspect import signature
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [24]:
torch.manual_seed(42)

<torch._C.Generator at 0x7813ac183330>

In [25]:
housing = fetch_california_housing()

In [26]:
class HousingDataset2(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            # Apply scaling if necessary
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X.astype(np.float32))
            self.y = torch.from_numpy(y.astype(np.float32))
            
            # Split dataset into separate variables
            self.MedInc = self.X[:,0]
            self.HouseAge = self.X[:,1]
            self.AveRooms = self.X[:,2]
            self.AveBedrms = self.X[:,3]
            self.Population = self.X[:,4]
            self.AveOccup = self.X[:,5]
            self.Latitude = self.X[:,6]
            self.Longitude = self.X[:,7]

    def __len__(self):
        return len(self.MedInc)

    def __getitem__(self, i):
        # Note: also returning combined X for ease of use later
        return self.MedInc[i], self.HouseAge[i], self.AveRooms[i], self.AveBedrms[i], self.Population[i], self.AveOccup[i], self.Latitude[i], self.Longitude[i], self.y[i]

In [27]:
dataset2 = HousingDataset2(housing.data, housing.target)
trainloader2 = torch.utils.data.DataLoader(dataset2, batch_size=10, shuffle=True, num_workers=1)

In [28]:
next(iter(trainloader2))

[tensor([ 0.4159, -0.3706, -0.2024,  0.1064,  0.1057,  0.3343, -0.2691, -0.0891,
          0.1397, -0.3399]),
 tensor([ 0.3465,  0.5849, -0.8454, -1.9578,  1.0616, -1.4811,  1.3000, -0.3686,
         -0.2097,  0.3465]),
 tensor([ 0.2294, -0.1514,  0.1928,  0.2968,  0.1675, -0.7187, -0.6491,  0.0260,
          0.1816, -0.4621]),
 tensor([-0.1113, -0.2040,  0.0087,  0.0363, -0.0081,  0.2041, -0.0872, -0.1563,
         -0.1881, -0.0519]),
 tensor([-0.3130,  0.3943,  0.5435,  2.7556, -0.3651,  0.0967,  1.7074,  0.3996,
          0.4049,  0.6115]),
 tensor([ 0.0527,  0.0550, -0.0279, -0.0217, -0.0372,  0.0529, -0.1372,  0.0449,
         -0.0229, -0.0284]),
 tensor([-0.5955, -0.8155,  0.9449, -0.4925, -0.6751, -0.8483, -0.7312,  1.1790,
         -0.7547, -0.6704]),
 tensor([ 0.3193,  0.6687, -1.2679,  0.7685,  0.7186,  0.8234,  0.6088, -1.3378,
          1.2676,  0.5189]),
 tensor([1.8790, 1.1770, 3.3160, 1.5430, 2.3400, 1.5240, 2.8750, 1.2360, 1.2100,
         2.0530])]

In [29]:
class MLP2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, inc, age, rms, bdrms, pop, occup, lat, lon):       
        combined = torch.column_stack((inc, age, rms, bdrms, pop, occup, lat, lon))
        return self.layers(combined)

In [30]:
# Initialize the MLP
mlp2 = MLP2().to(device)

In [31]:
# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp2.parameters(), lr=1e-4)

In [32]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader2, 0):

        # Get and prepare inputs
        a,b,c,d,e,f,g,h,targets = data
        a,b,c,d,e,f,g,h,targets = a.to(device),b.to(device),c.to(device),d.to(device),e.to(device),f.to(device),g.to(device),h.to(device),targets.to(device)
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp2(a,b,c,d,e,f,g,h)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

        # Print statistics
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.004
Loss after mini-batch   201: 0.733
Loss after mini-batch   401: 0.534
Loss after mini-batch   601: 0.403
Loss after mini-batch   801: 0.330
Loss after mini-batch  1001: 0.269
Loss after mini-batch  1201: 0.232
Loss after mini-batch  1401: 0.226
Loss after mini-batch  1601: 0.223
Loss after mini-batch  1801: 0.214
Loss after mini-batch  2001: 0.214
Starting epoch 2
Loss after mini-batch     1: 0.002
Loss after mini-batch   201: 0.211
Loss after mini-batch   401: 0.205
Loss after mini-batch   601: 0.199
Loss after mini-batch   801: 0.192
Loss after mini-batch  1001: 0.194
Loss after mini-batch  1201: 0.196
Loss after mini-batch  1401: 0.193
Loss after mini-batch  1601: 0.194
Loss after mini-batch  1801: 0.187
Loss after mini-batch  2001: 0.197
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.190
Loss after mini-batch   401: 0.183
Loss after mini-batch   601: 0.181
Loss after mini-batch   801: 0.190
Loss

### Save Model State Dict

In [33]:
torch.save(mlp2.state_dict(), "housing_model2.pt")
print("Saved PyTorch Model State to housing_model2.pt")

Saved PyTorch Model State to housing_model2.pt


### Save Model as TorchScript

In [34]:
scripted = torch.jit.script(mlp2)
scripted.save("ts_housing_model2.pt")
print("Saved TorchScript Model to ts_housing_model2.pt")

Saved TorchScript Model to ts_housing_model2.pt


### Load and Test from Model State

In [35]:
a,b,c,d,e,f,g,h,targets = next(iter(trainloader2))
a,b,c,d,e,f,g,h,targets = a.to(device), b.to(device), c.to(device), d.to(device), e.to(device), f.to(device), g.to(device), h.to(device), targets.to(device)

In [36]:
loaded_mlp2 = MLP2().to(device)
loaded_mlp2.load_state_dict(torch.load("housing_model2.pt", weights_only=True))

<All keys matched successfully>

In [37]:
loaded_mlp2(a,b,c,d,e,f,g,h)

tensor([[2.8778],
        [0.6233],
        [3.9021],
        [2.4543],
        [1.0209],
        [1.8093],
        [1.4593],
        [3.2933],
        [2.9263],
        [1.4790]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [38]:
print(signature(loaded_mlp2.forward))

(inc, age, rms, bdrms, pop, occup, lat, lon)


### Load and Test from TorchScript

In [39]:
scripted_mlp2 = torch.jit.load("ts_housing_model2.pt")

In [40]:
scripted_mlp2(a,b,c,d,e,f,g,h)

tensor([[2.8778],
        [0.6233],
        [3.9021],
        [2.4543],
        [1.0209],
        [1.8093],
        [1.4593],
        [3.2933],
        [2.9263],
        [1.4790]], device='cuda:0', grad_fn=<AddmmBackward0>)

## PySpark

### Convert dataset to Spark DataFrame

In [41]:
housing = fetch_california_housing()

In [42]:
X = StandardScaler().fit_transform(housing.data.astype(np.float32))

In [43]:
pdf = pd.DataFrame(X, columns=housing.feature_names)
pdf

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude
0,2.344766,0.982143,0.628559,-0.153758,-0.974429,-0.049597,1.052549,-1.327837
1,2.332238,-0.607019,0.327041,-0.263336,0.861439,-0.092512,1.043185,-1.322845
2,1.782699,1.856182,1.155620,-0.049016,-0.820777,-0.025843,1.038502,-1.332825
3,0.932967,1.856182,0.156966,-0.049833,-0.766028,-0.050329,1.038502,-1.337818
4,-0.012881,1.856182,0.344711,-0.032906,-0.759847,-0.085616,1.038502,-1.337818
...,...,...,...,...,...,...,...,...
20635,-1.216128,-0.289187,-0.155023,0.077354,-0.512592,-0.049110,1.801647,-0.758824
20636,-0.691593,-0.845393,0.276881,0.462365,-0.944405,0.005021,1.806329,-0.818721
20637,-1.142593,-0.924851,-0.090318,0.049414,-0.369537,-0.071734,1.778238,-0.823714
20638,-1.054583,-0.845393,-0.040211,0.158778,-0.604429,-0.091225,1.778238,-0.873626


In [44]:
foo = pdf.to_dict('series')

In [45]:
foo.keys()

dict_keys(['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude'])

In [46]:
pdf.dtypes

MedInc        float32
HouseAge      float32
AveRooms      float32
AveBedrms     float32
Population    float32
AveOccup      float32
Latitude      float32
Longitude     float32
dtype: object

In [47]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [48]:
import os
conda_env = os.environ.get("CONDA_PREFIX")

conf = SparkConf()
if 'spark' not in globals():
    # If Spark is not already started with Jupyter, attach to Spark Standalone
    import socket
    hostname = socket.gethostname()
    conf.setMaster(f"spark://{hostname}:7077") # assuming Master is on default port 7077
conf.set("spark.task.maxFailures", "1")
conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.python.worker.reuse", "true")
# Create Spark Session
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

24/10/08 00:36:22 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
24/10/08 00:36:22 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/08 00:36:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [49]:
# Spark is somehow auto-converting Pandas float32 to DoubleType(), so forcing FloatType()
schema = StructType([
StructField("MedInc",FloatType(),True),
StructField("HouseAge",FloatType(),True),
StructField("AveRooms",FloatType(),True),
StructField("AveBedrms",FloatType(),True),
StructField("Population",FloatType(),True),
StructField("AveOccup",FloatType(),True),
StructField("Latitude",FloatType(),True),
StructField("Longitude",FloatType(),True)
])

df = spark.createDataFrame(pdf, schema=schema).repartition(8)
df.show(truncate=12)

  if is_categorical_dtype(series.dtype):

                                                                                

+------------+------------+-----------+------------+-----------+------------+----------+------------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|
+------------+------------+-----------+------------+-----------+------------+----------+------------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|
|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.968

In [50]:
df.schema

StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])

### Save DataFrame as parquet

In [51]:
df.write.mode("overwrite").parquet("california_housing")

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [52]:
import os
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import array, struct, col
from pyspark.sql.types import ArrayType, FloatType

In [53]:
df = spark.read.parquet("california_housing")

In [54]:
columns = df.columns
columns

['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']

In [55]:
df.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|
+------------+------------+-----------+------------+-----------+------------+----------+------------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.0083897|
|  0.59671736|   0.5848523| 0.19346413|  -0.1371872|-0.19645879| 0.009964322|0.968

### Using TensorRT Model (single input)

In [56]:
# get absolute path to model
model_dir = "{}/housing_model.pt".format(os.getcwd())

In [57]:
def predict_batch_fn():
    import torch
    import torch_tensorrt as trt

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        raise ValueError("This function uses the TensorRT model which requires a GPU device")

    # Define model
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(8, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1)
            )

        def forward(self, x):
            return self.layers(x)

    model = MLP().to(device)
    model.load_state_dict(torch.load(model_dir, weights_only=True))

    # Preparing the inputs for dynamic batch sizing.
    inputs = [trt.Input(min_shape=(20, 8), 
                opt_shape=(50, 8), 
                max_shape=(64, 8), 
                dtype=torch.float32)]

    # Trace the computation graph and compile to produce an optimized module
    trt_gm = trt.compile(model, ir="dynamo", inputs=inputs)
    
    def predict(inputs):
        stream = torch.cuda.Stream()
        with torch.no_grad(), torch.cuda.stream(stream), trt.logging.errors():
            torch_inputs = torch.from_numpy(inputs).to(device)
            outputs = trt_gm(torch_inputs) # .flatten()
            return outputs.detach().cpu().numpy()

    return predict

In [59]:
classify = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=50)

In [60]:
%%time
preds = df.withColumn("preds", classify(struct(*columns)))
results = preds.collect()

                                                                                

CPU times: user 148 ms, sys: 17.9 ms, total: 166 ms
Wall time: 8.28 s


In [61]:
%%time
preds = df.withColumn("preds", classify(array(*columns)))
results = preds.collect()

CPU times: user 31.2 ms, sys: 5.84 ms, total: 37.1 ms
Wall time: 257 ms


In [62]:
# should raise ValueError
# preds = df.withColumn("preds", classify(*columns))
# results = preds.collect()

In [63]:
preds.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|     preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053| 1.3980268|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742| 1.7447104|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378|  1.439564|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787| 2.4199378|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.2448893|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.008

### Using TorchScript Model (separate input variables)

In [64]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col
from pyspark.sql.types import ArrayType, FloatType

In [65]:
df = spark.read.parquet("california_housing")

In [66]:
columns = df.columns
columns

['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']

In [67]:
# get absolute path to model
model2_dir = "{}/ts_housing_model2.pt".format(os.getcwd())

In [68]:
def predict_batch_fn():
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scripted_mlp = torch.jit.load(model2_dir)
    scripted_mlp.to(device)
    
    def predict(inc, age, rms, bdrms, pop, occ, lat, lon):
        # print input shape
        outputs = scripted_mlp(
            torch.from_numpy(inc).to(device),
            torch.from_numpy(age).to(device),
            torch.from_numpy(rms).to(device),
            torch.from_numpy(bdrms).to(device),
            torch.from_numpy(pop).to(device),
            torch.from_numpy(occ).to(device),
            torch.from_numpy(lat).to(device),
            torch.from_numpy(lon).to(device),
        )
        return outputs.detach().cpu().numpy()

    return predict

In [69]:
classify = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             batch_size=50)

In [70]:
%%time
# first pass caches model/fn
preds = df.withColumn("preds", classify(struct(*columns)))
results = preds.collect()



CPU times: user 16.7 ms, sys: 5.44 ms, total: 22.1 ms
Wall time: 1.13 s


                                                                                

In [71]:
%%time
preds = df.withColumn("preds", classify(*columns))
results = preds.collect()

CPU times: user 19.6 ms, sys: 8.68 ms, total: 28.3 ms
Wall time: 451 ms


In [72]:
preds.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|     preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053| 1.3979516|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742| 1.7442212|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378| 1.4398992|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787| 2.4199052|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.2446644|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.008

### Using Triton Inference Server

Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [73]:
import numpy as np

from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col, array
from pyspark.sql.types import ArrayType, FloatType, Union, Dict

In [74]:
%%bash
# copy custom model to expected layout for Triton
rm -rf models
mkdir models
cp -r models_config/housing_model models
mkdir -p models/housing_model/1
cp ts_housing_model.pt models/housing_model/1/model.pt

#### Start Triton Server on each executor

In [75]:
num_executors = 1
triton_models_dir = "{}/models".format(os.getcwd())
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    import docker
    import time
    import tritonclient.grpc as grpcclient
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container=client.containers.run(
            "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="64M",
            volumes={triton_models_dir: {"bind": "/models", "mode": "ro"}}
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        client = grpcclient.InferenceServerClient("localhost:8001")
        ready = False
        while not ready:
            try:
                ready = client.is_server_ready()
            except Exception as e:
                time.sleep(5)
            
    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

                                                                                

[True]

### Run Inference

In [76]:
df = spark.read.parquet("california_housing")

In [77]:
columns = df.columns
columns

['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']

In [78]:
def triton_fn(triton_uri, model_name):
    import numpy as np
    import tritonclient.grpc as grpcclient
    
    np_types = {
      "BOOL": np.dtype(np.bool_),
      "INT8": np.dtype(np.int8),
      "INT16": np.dtype(np.int16),
      "INT32": np.dtype(np.int32),
      "INT64": np.dtype(np.int64),
      "FP16": np.dtype(np.float16),
      "FP32": np.dtype(np.float32),
      "FP64": np.dtype(np.float64),
      "FP64": np.dtype(np.double),
      "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)
    
    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))
        
        response = client.infer(model_name, inputs=request)
        
        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)
        
    return predict

In [79]:
classify = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="housing_model"),
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=500)

In [80]:
%%time
# first pass caches model/fn
predictions = df.withColumn("preds", classify(struct(*columns)))
preds = predictions.collect()

CPU times: user 15.3 ms, sys: 4.33 ms, total: 19.6 ms
Wall time: 461 ms


In [81]:
%%time
predictions = df.withColumn("preds", classify(array(*columns)))
preds = predictions.collect()

CPU times: user 30.2 ms, sys: 5.78 ms, total: 36 ms
Wall time: 200 ms


In [82]:
# should raise ValueError
# predictions = df.withColumn("preds", classify(*columns))
# preds = predictions.collect()

In [83]:
predictions.show()

+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|      MedInc|    HouseAge|   AveRooms|   AveBedrms| Population|    AveOccup|  Latitude|   Longitude|     preds|
+------------+------------+-----------+------------+-----------+------------+----------+------------+----------+
|  0.20909257|  -1.1632254| 0.38946992|  0.04609274| -0.9806099| -0.07099328|0.61245227|-0.020113053| 1.3979516|
|-0.098627955|  0.34647804| 0.27216315|  -0.0129226| -0.6953838| -0.05380849| 1.0665938|  -1.2479742| 1.7442212|
| -0.66006273|   1.0616008|-0.55292207| -0.48945764|-0.13641118| 0.028952759| 1.1040496|  -1.3827378| 1.4398992|
|  0.08218294|   0.5848523|-0.13912922| -0.14707813|-0.19116047| -0.07136432|0.96827507|  -1.3028787| 2.4199052|
|   0.0784456|  -1.4810578| 0.57265776|  0.32067496|  1.0345173|-0.024157424| 1.4411427| -0.52423614| 1.2446644|
| -0.82318723| -0.36864465| 0.07829511|  -0.1808107|-0.67242444|-0.061470542| 1.9374212|  -1.008

#### Stop Triton Server on each executor

In [84]:
def stop_triton(it):
    import docker
    import time
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        container=containers[0]
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

                                                                                

[True]

In [85]:
spark.stop()