Based on: https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md

In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f20ee95fc30>

In [3]:
X, y = fetch_california_housing(return_X_y=True)

In [4]:
class HousingDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            # Apply scaling if necessary
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X.astype(np.float32))
            self.y = torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]

In [5]:
dataset = HousingDataset(X, y)
trainloader = torch.utils.data.DataLoader(
    dataset, batch_size=10, shuffle=True, num_workers=1)

In [6]:
next(iter(trainloader))

[tensor([[ 0.4159,  0.3465,  0.2294, -0.1113, -0.3130,  0.0527, -0.5955,  0.3193],
         [-0.3706,  0.5849, -0.1514, -0.2040,  0.3943,  0.0550, -0.8155,  0.6687],
         [-0.2024, -0.8454,  0.1928,  0.0087,  0.5435, -0.0279,  0.9449, -1.2679],
         [ 0.1064, -1.9578,  0.2968,  0.0363,  2.7556, -0.0217, -0.4925,  0.7685],
         [ 0.1057,  1.0616,  0.1675, -0.0081, -0.3651, -0.0372, -0.6751,  0.7186],
         [ 0.3343, -1.4811, -0.7187,  0.2041,  0.0967,  0.0529, -0.8483,  0.8234],
         [-0.2691,  1.3000, -0.6491, -0.0872,  1.7074, -0.1372, -0.7312,  0.6088],
         [-0.0891, -0.3686,  0.0260, -0.1563,  0.3996,  0.0449,  1.1790, -1.3378],
         [ 0.1397, -0.2097,  0.1816, -0.1881,  0.4049, -0.0229, -0.7547,  1.2676],
         [-0.3399,  0.3465, -0.4621, -0.0519,  0.6115, -0.0284, -0.6704,  0.5189]]),
 tensor([1.8790, 1.1770, 3.3160, 1.5430, 2.3400, 1.5240, 2.8750, 1.2360, 1.2100,
         2.0530])]

In [7]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [8]:
# Initialize the MLP
mlp = MLP()

# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

In [9]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader, 0):

        # Get and prepare inputs
        inputs, targets = data
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp(inputs)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

        # Print statistics
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.004
Loss after mini-batch   201: 0.733
Loss after mini-batch   401: 0.534
Loss after mini-batch   601: 0.403
Loss after mini-batch   801: 0.330
Loss after mini-batch  1001: 0.269
Loss after mini-batch  1201: 0.232
Loss after mini-batch  1401: 0.226
Loss after mini-batch  1601: 0.223
Loss after mini-batch  1801: 0.214
Loss after mini-batch  2001: 0.214
Starting epoch 2
Loss after mini-batch     1: 0.002
Loss after mini-batch   201: 0.211
Loss after mini-batch   401: 0.205
Loss after mini-batch   601: 0.199
Loss after mini-batch   801: 0.192
Loss after mini-batch  1001: 0.194
Loss after mini-batch  1201: 0.196
Loss after mini-batch  1401: 0.193
Loss after mini-batch  1601: 0.194
Loss after mini-batch  1801: 0.187
Loss after mini-batch  2001: 0.197
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.190
Loss after mini-batch   401: 0.183
Loss after mini-batch   601: 0.181
Loss after mini-batch   801: 0.190
Loss

### Save Model

In [10]:
torch.save(mlp, "housing_model.pt")

In [11]:
scripted = torch.jit.script(mlp)
scripted.save("housing_model.ts")

### Load and Test Model

In [12]:
loaded_mlp = torch.load("housing_model.pt")

In [13]:
testX, testY = next(iter(trainloader))

In [14]:
loaded_mlp(testX)

tensor([[2.8778],
        [0.6233],
        [3.9021],
        [2.4543],
        [1.0209],
        [1.8093],
        [1.4593],
        [3.2933],
        [2.9263],
        [1.4790]], grad_fn=<AddmmBackward0>)

In [15]:
testY

tensor([2.8380, 0.5740, 5.0000, 2.0430, 1.2680, 1.8380, 1.4340, 2.6830, 1.6100,
        1.3050])

In [16]:
scripted_mlp = torch.jit.load("housing_model.ts")

In [17]:
scripted_mlp(testX).flatten()

tensor([2.8778, 0.6233, 3.9021, 2.4543, 1.0209, 1.8093, 1.4593, 3.2933, 2.9263,
        1.4790], grad_fn=<ReshapeAliasBackward0>)

### Columns as separate input variables

In [18]:
import numpy as np
import pandas as pd
import torch

from inspect import signature
from torch import nn
from torch.utils.data import DataLoader
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

In [19]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f20ee95fc30>

In [20]:
housing = fetch_california_housing()

In [21]:
class HousingDataset2(torch.utils.data.Dataset):
    def __init__(self, X, y, scale_data=True):
        if not torch.is_tensor(X) and not torch.is_tensor(y):
            # Apply scaling if necessary
            if scale_data:
                X = StandardScaler().fit_transform(X)
            self.X = torch.from_numpy(X.astype(np.float32))
            self.y = torch.from_numpy(y.astype(np.float32))
            
            # Split dataset into separate variables
            self.MedInc = self.X[:,0]
            self.HouseAge = self.X[:,1]
            self.AveRooms = self.X[:,2]
            self.AveBedrms = self.X[:,3]
            self.Population = self.X[:,4]
            self.AveOccup = self.X[:,5]
            self.Latitude = self.X[:,6]
            self.Longitude = self.X[:,7]

    def __len__(self):
        return len(self.MedInc)

    def __getitem__(self, i):
        # Note: also returning combined X for ease of use later
        return self.MedInc[i], self.HouseAge[i], self.AveRooms[i], self.AveBedrms[i], self.Population[i], self.AveOccup[i], self.Latitude[i], self.Longitude[i], self.y[i]

In [22]:
dataset2 = HousingDataset2(housing.data, housing.target)
trainloader2 = torch.utils.data.DataLoader(dataset2, batch_size=10, shuffle=True, num_workers=1)

In [23]:
next(iter(trainloader2))

[tensor([ 0.4159, -0.3706, -0.2024,  0.1064,  0.1057,  0.3343, -0.2691, -0.0891,
          0.1397, -0.3399]),
 tensor([ 0.3465,  0.5849, -0.8454, -1.9578,  1.0616, -1.4811,  1.3000, -0.3686,
         -0.2097,  0.3465]),
 tensor([ 0.2294, -0.1514,  0.1928,  0.2968,  0.1675, -0.7187, -0.6491,  0.0260,
          0.1816, -0.4621]),
 tensor([-0.1113, -0.2040,  0.0087,  0.0363, -0.0081,  0.2041, -0.0872, -0.1563,
         -0.1881, -0.0519]),
 tensor([-0.3130,  0.3943,  0.5435,  2.7556, -0.3651,  0.0967,  1.7074,  0.3996,
          0.4049,  0.6115]),
 tensor([ 0.0527,  0.0550, -0.0279, -0.0217, -0.0372,  0.0529, -0.1372,  0.0449,
         -0.0229, -0.0284]),
 tensor([-0.5955, -0.8155,  0.9449, -0.4925, -0.6751, -0.8483, -0.7312,  1.1790,
         -0.7547, -0.6704]),
 tensor([ 0.3193,  0.6687, -1.2679,  0.7685,  0.7186,  0.8234,  0.6088, -1.3378,
          1.2676,  0.5189]),
 tensor([1.8790, 1.1770, 3.3160, 1.5430, 2.3400, 1.5240, 2.8750, 1.2360, 1.2100,
         2.0530])]

In [24]:
class MLP2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, inc, age, rms, bdrms, pop, occup, lat, lon):       
        combined = torch.column_stack((inc, age, rms, bdrms, pop, occup, lat, lon))
        return self.layers(combined)

In [25]:
# Initialize the MLP
mlp2 = MLP2()

In [26]:
# Define the loss function and optimizer
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(mlp2.parameters(), lr=1e-4)

In [27]:
# Run the training loop
for epoch in range(0, 5):  # 5 epochs at maximum

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader2, 0):

        # Get and prepare inputs
        a,b,c,d,e,f,g,h,targets = data
        targets = targets.reshape((targets.shape[0], 1))

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = mlp2(a,b,c,d,e,f,g,h)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()

        # Print statistics
        current_loss += loss.item()
        if i % 200 == 0:
            print('Loss after mini-batch %5d: %.3f' %
                  (i + 1, current_loss / 500))
            current_loss = 0.0

# Process is complete.
print('Training process has finished.')

Starting epoch 1
Loss after mini-batch     1: 0.004
Loss after mini-batch   201: 0.733
Loss after mini-batch   401: 0.534
Loss after mini-batch   601: 0.403
Loss after mini-batch   801: 0.330
Loss after mini-batch  1001: 0.269
Loss after mini-batch  1201: 0.232
Loss after mini-batch  1401: 0.226
Loss after mini-batch  1601: 0.223
Loss after mini-batch  1801: 0.214
Loss after mini-batch  2001: 0.214
Starting epoch 2
Loss after mini-batch     1: 0.002
Loss after mini-batch   201: 0.211
Loss after mini-batch   401: 0.205
Loss after mini-batch   601: 0.199
Loss after mini-batch   801: 0.192
Loss after mini-batch  1001: 0.194
Loss after mini-batch  1201: 0.196
Loss after mini-batch  1401: 0.193
Loss after mini-batch  1601: 0.194
Loss after mini-batch  1801: 0.187
Loss after mini-batch  2001: 0.197
Starting epoch 3
Loss after mini-batch     1: 0.001
Loss after mini-batch   201: 0.190
Loss after mini-batch   401: 0.183
Loss after mini-batch   601: 0.181
Loss after mini-batch   801: 0.190
Loss

### Save Model

In [28]:
torch.save(mlp2, "housing_model2.pt")

In [29]:
scripted = torch.jit.script(mlp2)
scripted.save("housing_model2.ts")

### Load and Test Model

In [30]:
a,b,c,d,e,f,g,h,targets = next(iter(trainloader2))

In [31]:
loaded_mlp2 = torch.load("housing_model2.pt")

In [32]:
loaded_mlp2(a,b,c,d,e,f,g,h)

tensor([[2.8778],
        [0.6233],
        [3.9021],
        [2.4543],
        [1.0209],
        [1.8093],
        [1.4593],
        [3.2933],
        [2.9263],
        [1.4790]], grad_fn=<AddmmBackward0>)

In [33]:
print(signature(loaded_mlp2.forward))

(inc, age, rms, bdrms, pop, occup, lat, lon)


In [34]:
scripted_mlp2 = torch.jit.load("housing_model2.ts")

In [35]:
scripted_mlp2(a,b,c,d,e,f,g,h)

tensor([[2.8778],
        [0.6233],
        [3.9021],
        [2.4543],
        [1.0209],
        [1.8093],
        [1.4593],
        [3.2933],
        [2.9263],
        [1.4790]], grad_fn=<AddmmBackward0>)

## PySpark

### Convert dataset to Spark DataFrame

In [36]:
housing = fetch_california_housing()

In [37]:
X = StandardScaler().fit_transform(housing.data.astype(np.float32))

In [38]:
pdf = pd.DataFrame(X, columns=housing.feature_names)
pdf

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude
0,2.344766,0.982143,0.628559,-0.153758,-0.974429,-0.049597,1.052549,-1.327837
1,2.332238,-0.607019,0.327041,-0.263336,0.861439,-0.092512,1.043185,-1.322845
2,1.782699,1.856182,1.155620,-0.049016,-0.820777,-0.025843,1.038502,-1.332825
3,0.932967,1.856182,0.156966,-0.049833,-0.766028,-0.050329,1.038502,-1.337818
4,-0.012881,1.856182,0.344711,-0.032906,-0.759847,-0.085616,1.038502,-1.337818
...,...,...,...,...,...,...,...,...
20635,-1.216128,-0.289187,-0.155023,0.077354,-0.512592,-0.049110,1.801647,-0.758824
20636,-0.691593,-0.845393,0.276881,0.462365,-0.944405,0.005021,1.806329,-0.818721
20637,-1.142593,-0.924851,-0.090318,0.049414,-0.369537,-0.071734,1.778238,-0.823714
20638,-1.054583,-0.845393,-0.040211,0.158778,-0.604429,-0.091225,1.778238,-0.873626


In [39]:
foo = pdf.to_dict('series')

In [40]:
foo.keys()

dict_keys(['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup', 'Latitude', 'Longitude'])

In [41]:
pdf.dtypes

MedInc        float32
HouseAge      float32
AveRooms      float32
AveBedrms     float32
Population    float32
AveOccup      float32
Latitude      float32
Longitude     float32
dtype: object

In [42]:
from pyspark.sql.types import *

# Spark is somehow auto-converting Pandas float32 to DoubleType(), so forcing FloatType()
schema = StructType([
StructField("MedInc",FloatType(),True),
StructField("HouseAge",FloatType(),True),
StructField("AveRooms",FloatType(),True),
StructField("AveBedrms",FloatType(),True),
StructField("Population",FloatType(),True),
StructField("AveOccup",FloatType(),True),
StructField("Latitude",FloatType(),True),
StructField("Longitude",FloatType(),True)
])

df = spark.createDataFrame(pdf, schema=schema)
df.show(truncate=12)

[Stage 0:>                                                          (0 + 1) / 1]

+------------+----------+------------+------------+-----------+------------+---------+----------+
|      MedInc|  HouseAge|    AveRooms|   AveBedrms| Population|    AveOccup| Latitude| Longitude|
+------------+----------+------------+------------+-----------+------------+---------+----------+
|    2.344766| 0.9821427|  0.62855947| -0.15375753|-0.97442853|-0.049596533|1.0525488|-1.3278369|
|   2.3322382|-0.6070189|  0.32704142| -0.26333576|  0.8614389| -0.09251223|1.0431849|-1.3228445|
|   1.7826993| 1.8561815|   1.1556205|-0.049016476|-0.82077736|-0.025842525| 1.038502|-1.3328254|
|  0.93296736| 1.8561815|  0.15696616|-0.049833003|-0.76602805|-0.050329294| 1.038502|-1.3378178|
|-0.012881001| 1.8561815|  0.34471077|-0.032905966| -0.7598467| -0.08561575| 1.038502|-1.3378178|
| 0.087446585| 1.8561815| -0.26972958| 0.014669393|-0.89407074|-0.089618415| 1.038502|-1.3378178|
| -0.11136628| 1.8561815| -0.20091766| -0.30663314| -0.2927116| -0.09072491|1.0338209|-1.3378178|
| -0.39513668| 1.856

                                                                                

In [43]:
df.schema

StructType([StructField('MedInc', FloatType(), True), StructField('HouseAge', FloatType(), True), StructField('AveRooms', FloatType(), True), StructField('AveBedrms', FloatType(), True), StructField('Population', FloatType(), True), StructField('AveOccup', FloatType(), True), StructField('Latitude', FloatType(), True), StructField('Longitude', FloatType(), True)])

### Save DataFrame as parquet

In [44]:
df.write.mode("overwrite").parquet("california_housing")

                                                                                

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [45]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import array, struct, col
from pyspark.sql.types import ArrayType, FloatType

In [46]:
df = spark.read.parquet("california_housing")

In [47]:
columns = df.columns
columns

['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']

In [48]:
df.show()

+-----------+-----------+------------+-------------+------------+-------------+----------+------------+
|     MedInc|   HouseAge|    AveRooms|    AveBedrms|  Population|     AveOccup|  Latitude|   Longitude|
+-----------+-----------+------------+-------------+------------+-------------+----------+------------+
|  -1.669445|-0.20972852|  -1.1155425|  -0.17418891|   0.2070965|  -0.13437383| 0.5094524| -0.08001011|
| -0.8564016| -1.5605159| -0.53141737|  -0.02190494| -0.66006166|  -0.12997007| 0.5141335| -0.08001011|
| 0.73173314|-0.76593506|  0.67250663|  -0.10979619|  0.14175056|  -0.02296524| 0.5141335| -0.07002536|
|-0.44887984|  -1.719432|  0.14690235| -0.009905009| -0.06664997|   0.01090982|0.50476956|-0.075017735|
|-0.96920437| -1.0837674| 0.058284093|   0.07351444| -0.20793848| -0.011342554| 0.5094524| -0.08001011|
| -1.0906925| 0.26701993| -0.72011936|   -0.2243873|  -0.6574125|  0.021105729|0.49072453| -0.08001011|
|-0.90240705|  0.5848523|  0.24784403|   0.08091579|  -0.5205393

### Using TorchScript Model (single input)

In [49]:
# get absolute path to model
model_dir = "{}/housing_model.ts".format(os.getcwd())

In [50]:
def predict_batch_fn():
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))
    
    scripted_mlp = torch.jit.load(model_dir)
    scripted_mlp.to(device)
    
    def predict(inputs):
        torch_inputs = torch.from_numpy(inputs).to(device)
        outputs = scripted_mlp(torch_inputs) # .flatten()
        return outputs.detach().numpy()

    return predict

In [51]:
classify = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=50)

In [52]:
%%time
preds = df.withColumn("preds", classify(struct(*columns)))
results = preds.collect()

                                                                                

CPU times: user 200 ms, sys: 6.02 ms, total: 206 ms
Wall time: 3.24 s


In [53]:
%%time
preds = df.withColumn("preds", classify(array(*columns)))
results = preds.collect()

CPU times: user 48.7 ms, sys: 2.68 ms, total: 51.4 ms
Wall time: 540 ms


In [54]:
# should raise ValueError
# preds = df.withColumn("preds", classify(*columns))
# results = preds.collect()

In [55]:
preds.show()

+-----------+-----------+------------+-------------+------------+-------------+----------+------------+----------+
|     MedInc|   HouseAge|    AveRooms|    AveBedrms|  Population|     AveOccup|  Latitude|   Longitude|     preds|
+-----------+-----------+------------+-------------+------------+-------------+----------+------------+----------+
|  -1.669445|-0.20972852|  -1.1155425|  -0.17418891|   0.2070965|  -0.13437383| 0.5094524| -0.08001011| 0.8521974|
| -0.8564016| -1.5605159| -0.53141737|  -0.02190494| -0.66006166|  -0.12997007| 0.5141335| -0.08001011| 1.1964239|
| 0.73173314|-0.76593506|  0.67250663|  -0.10979619|  0.14175056|  -0.02296524| 0.5141335| -0.07002536| 1.7371099|
|-0.44887984|  -1.719432|  0.14690235| -0.009905009| -0.06664997|   0.01090982|0.50476956|-0.075017735| 1.0491724|
|-0.96920437| -1.0837674| 0.058284093|   0.07351444| -0.20793848| -0.011342554| 0.5094524| -0.08001011| 0.8659589|
| -1.0906925| 0.26701993| -0.72011936|   -0.2243873|  -0.6574125|  0.021105729|0

### Using TorchScript Model (separate input variables)

In [56]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col
from pyspark.sql.types import ArrayType, FloatType

In [57]:
df = spark.read.parquet("california_housing")

In [58]:
columns = df.columns
columns

['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']

In [59]:
# get absolute path to model
model2_dir = "{}/housing_model2.ts".format(os.getcwd())

In [60]:
def predict_batch_fn():
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))
    scripted_mlp = torch.jit.load(model2_dir)
    scripted_mlp.to(device)
    
    def predict(inc, age, rms, bdrms, pop, occ, lat, lon):
        outputs = scripted_mlp(
            torch.from_numpy(inc).to(device),
            torch.from_numpy(age).to(device),
            torch.from_numpy(rms).to(device),
            torch.from_numpy(bdrms).to(device),
            torch.from_numpy(pop).to(device),
            torch.from_numpy(occ).to(device),
            torch.from_numpy(lat).to(device),
            torch.from_numpy(lon).to(device),
        )
        return outputs.detach().numpy()

    return predict

In [61]:
classify = predict_batch_udf(predict_batch_fn,
                             return_type=FloatType(),
                             batch_size=50)

In [62]:
%%time
# first pass caches model/fn
preds = df.withColumn("preds", classify(struct(*columns)))
results = preds.collect()

                                                                                

CPU times: user 192 ms, sys: 4.28 ms, total: 196 ms
Wall time: 1.89 s


In [63]:
# should fail with ValueError
# preds = df.withColumn("preds", classify(array(*columns)))
# results = preds.collect()

In [64]:
%%time
preds = df.withColumn("preds", classify(*columns))
results = preds.collect()

CPU times: user 64.7 ms, sys: 584 µs, total: 65.3 ms
Wall time: 363 ms


In [65]:
preds.show()

+-----------+-----------+------------+-------------+------------+-------------+----------+------------+----------+
|     MedInc|   HouseAge|    AveRooms|    AveBedrms|  Population|     AveOccup|  Latitude|   Longitude|     preds|
+-----------+-----------+------------+-------------+------------+-------------+----------+------------+----------+
|  -1.669445|-0.20972852|  -1.1155425|  -0.17418891|   0.2070965|  -0.13437383| 0.5094524| -0.08001011| 0.8521974|
| -0.8564016| -1.5605159| -0.53141737|  -0.02190494| -0.66006166|  -0.12997007| 0.5141335| -0.08001011| 1.1964239|
| 0.73173314|-0.76593506|  0.67250663|  -0.10979619|  0.14175056|  -0.02296524| 0.5141335| -0.07002536| 1.7371099|
|-0.44887984|  -1.719432|  0.14690235| -0.009905009| -0.06664997|   0.01090982|0.50476956|-0.075017735| 1.0491724|
|-0.96920437| -1.0837674| 0.058284093|   0.07351444| -0.20793848| -0.011342554| 0.5094524| -0.08001011| 0.8659589|
| -1.0906925| 0.26701993| -0.72011936|   -0.2243873|  -0.6574125|  0.021105729|0

### Using Triton Inference Server

Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [66]:
import numpy as np

from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col, array
from pyspark.sql.types import ArrayType, FloatType, Union, Dict

In [67]:
%%bash
# copy custom model to expected layout for Triton
rm -rf models
mkdir models
cp -r models_config/housing_model models
mkdir -p models/housing_model/1
cp housing_model.ts models/housing_model/1/model.pt

#### Start Triton Server on each executor

In [68]:
num_executors = 1
triton_models_dir = "{}/models".format(os.getcwd())
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    import docker
    import time
    import tritonclient.grpc as grpcclient
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container=client.containers.run(
            "nvcr.io/nvidia/tritonserver:22.07-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="64M",
            volumes={triton_models_dir: {"bind": "/models", "mode": "ro"}}
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        client = grpcclient.InferenceServerClient("localhost:8001")
        ready = False
        while not ready:
            try:
                ready = client.is_server_ready()
            except Exception as e:
                time.sleep(5)
            
    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

                                                                                

[True]

### Run Inference

In [69]:
df = spark.read.parquet("california_housing")

In [70]:
columns = df.columns
columns

['MedInc',
 'HouseAge',
 'AveRooms',
 'AveBedrms',
 'Population',
 'AveOccup',
 'Latitude',
 'Longitude']

In [71]:
def triton_fn(triton_uri, model_name):
    import numpy as np
    import tritonclient.grpc as grpcclient
    
    np_types = {
      "BOOL": np.dtype(np.bool8),
      "INT8": np.dtype(np.int8),
      "INT16": np.dtype(np.int16),
      "INT32": np.dtype(np.int32),
      "INT64": np.dtype(np.int64),
      "FP16": np.dtype(np.float16),
      "FP32": np.dtype(np.float32),
      "FP64": np.dtype(np.float64),
      "FP64": np.dtype(np.double),
      "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)
    
    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))
        
        response = client.infer(model_name, inputs=request)
        
        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)
        
    return predict

In [72]:
classify = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="housing_model"),
                             return_type=FloatType(),
                             input_tensor_shapes=[[8]],
                             batch_size=500)

In [73]:
%%time
# first pass caches model/fn
predictions = df.withColumn("preds", classify(struct(*columns)))
preds = predictions.collect()

[Stage 13:>                                                         (0 + 8) / 8]

CPU times: user 74.3 ms, sys: 9.28 ms, total: 83.6 ms
Wall time: 1.07 s


                                                                                

In [74]:
%%time
predictions = df.withColumn("preds", classify(array(*columns)))
preds = predictions.collect()

CPU times: user 197 ms, sys: 24.2 ms, total: 221 ms
Wall time: 400 ms


In [75]:
# should raise ValueError
# predictions = df.withColumn("preds", classify(*columns))
# preds = predictions.collect()

In [76]:
predictions.show()

+-----------+-----------+------------+-------------+------------+-------------+----------+------------+----------+
|     MedInc|   HouseAge|    AveRooms|    AveBedrms|  Population|     AveOccup|  Latitude|   Longitude|     preds|
+-----------+-----------+------------+-------------+------------+-------------+----------+------------+----------+
|  -1.669445|-0.20972852|  -1.1155425|  -0.17418891|   0.2070965|  -0.13437383| 0.5094524| -0.08001011|0.85219747|
| -0.8564016| -1.5605159| -0.53141737|  -0.02190494| -0.66006166|  -0.12997007| 0.5141335| -0.08001011| 1.1964238|
| 0.73173314|-0.76593506|  0.67250663|  -0.10979619|  0.14175056|  -0.02296524| 0.5141335| -0.07002536| 1.7371097|
|-0.44887984|  -1.719432|  0.14690235| -0.009905009| -0.06664997|   0.01090982|0.50476956|-0.075017735| 1.0491724|
|-0.96920437| -1.0837674| 0.058284093|   0.07351444| -0.20793848| -0.011342554| 0.5094524| -0.08001011| 0.8659591|
| -1.0906925| 0.26701993| -0.72011936|   -0.2243873|  -0.6574125|  0.021105729|0

#### Stop Triton Server on each executor

In [77]:
def stop_triton(it):
    import docker
    import time
    
    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        container=containers[0]
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

                                                                                

[True]

In [78]:
spark.stop()