<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark DL Inference
### Conditional generation with Hugging Face

In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation.  
Adapted from the Hugging Face T5 documentation: https://huggingface.co/docs/transformers/model_doc/t5

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Explicitly enable Hugging Face tokenizer parallelism so that it is not disabled when combined with PySpark parallelism.
# See https://github.com/huggingface/transformers/issues/5486 for more info.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [2]:
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company"
]

input_sequences = [task_prefix + l for l in lines]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [3]:
inputs = tokenizer(input_sequences,
                      padding=True, 
                      return_tensors="pt")

outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=128)

In [4]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

['Das Haus ist wunderbar',
 'Willkommen in NYC',
 'HuggingFace ist ein Unternehmen']

## PySpark

In [5]:
from pyspark.sql.types import StringType  # the only type this notebook uses
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct
from pyspark.ml.functions import predict_batch_udf

In [6]:
import json
import pandas as pd
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

Check the cluster environment to handle any platform-specific Spark configurations.

In [7]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)
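
The three flags above can be folded into a single helper, which makes the check easy to exercise without touching the real environment. This is a minimal sketch: `detect_platform` is a hypothetical name, and only the two environment variable names come from the cell above.

```python
import os
from typing import Mapping, Optional


def detect_platform(env: Optional[Mapping[str, str]] = None) -> str:
    """Return 'databricks', 'dataproc', or 'standalone' (hypothetical helper)."""
    # Fall back to the real environment only when no mapping is supplied.
    env = os.environ if env is None else env
    if env.get("DATABRICKS_RUNTIME_VERSION"):
        return "databricks"
    if env.get("DATAPROC_IMAGE_VERSION"):
        return "dataproc"
    return "standalone"


print(detect_platform({"DATABRICKS_RUNTIME_VERSION": "15.4"}))  # databricks
print(detect_platform({}))  # standalone
```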

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For cloud service provider (CSP) environments, the Spark Session is either preconfigured (Databricks) or must be created by the notebook (Dataproc).

In [8]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/02/04 13:34:55 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/02/04 13:34:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/04 13:34:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
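
The core and GPU settings in the configuration cell together determine how many tasks can run at once on each executor. The arithmetic can be sketched as follows, assuming the Spark default of `spark.task.cpus=1`, which this notebook does not set. `max_concurrent_tasks` is a hypothetical helper for illustration, not a Spark API.

```python
from fractions import Fraction


def max_concurrent_tasks(executor_cores: int,
                         executor_gpus: str,
                         task_gpu_amount: str,
                         task_cpus: int = 1) -> int:
    """Tasks per executor are capped by both the CPU limit and the GPU limit."""
    by_cpu = executor_cores // task_cpus
    # Fractions avoid floating-point surprises with amounts such as 0.125.
    by_gpu = int(Fraction(executor_gpus) / Fraction(task_gpu_amount))
    return min(by_cpu, by_gpu)


# Values from the configuration cell: 8 cores, 1 GPU, 0.125 GPU per task.
print(max_concurrent_tasks(8, "1", "0.125"))  # 8
```

With these values, eight tasks share one GPU, so the GPU fraction and the core count impose the same limit.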


Load the IMDB Movie Reviews dataset from Hugging Face.

In [9]:
dataset = load_dataset("imdb", split="test")
dataset = dataset.to_pandas().drop(columns="label")

#### Create PySpark DataFrame

In [10]:
df = spark.createDataFrame(dataset).repartition(8)
df.schema

StructType([StructField('text', StringType(), True)])

In [11]:
df.count()

25000

In [12]:
df.take(1)

25/02/04 13:35:02 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.


[Row(text="Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.")]

In [13]:
data_path = "spark-dl-datasets/imdb_test"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

25/02/04 13:35:02 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.


#### Load and preprocess DataFrame

Define our preprocess function. We'll take the first sentence from each sample (the text before the first period) as our input for translation.

In [14]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)
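
To see what this rule does to individual reviews, here is a plain-Python sketch of the same string operation, without Spark or pandas. `first_sentence` is a hypothetical name; the split rule is the one used in `_preprocess`.

```python
def first_sentence(text: str, prefix: str = "") -> str:
    """Same rule as _preprocess: keep everything before the first period."""
    return prefix + text.split(".")[0]


prefix = "translate English to French: "
print(first_sentence("A good cast... A good idea", prefix))
# translate English to French: A good cast
print(first_sentence("across the overturned S.S. Poseidon, hoping", prefix))
# translate English to French: across the overturned S
```

Splitting on the first period is cheap, but it also cuts at abbreviations such as "S.S.", so some inputs end mid-phrase. That is acceptable for a throughput demonstration like this one.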

In [15]:
# Limit to 512 rows, since inference can be slow
df = spark.read.parquet(data_path).limit(512).repartition(8)
df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                                text|
+----------------------------------------------------------------------------------------------------+
|The only reason I'm even giving this movie a 4 is because it was made in to an episode of Mystery...|
|Awkward disaster mishmash has a team of scavengers coming across the overturned S.S. Poseidon, ho...|
|Here is a fantastic concept for a film - a series of meteors crash into a small town and the resu...|
|I walked out of the cinema having suffered this film after 30 mins. I left two friends pinned in ...|
|A wildly uneven film where the major problem is the uneasy mix of comedy and thriller. To me, the...|
|Leonard Rossiter and Frances de la Tour carry this film, not without a struggle, as the script wa...|
|A good cast... A good idea but turns out it is flawed as hypnosis is not

Append a prefix to tell the model to translate English to French:

In [16]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()
input_df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|translate English to French: The only reason I'm even giving this movie a 4 is because it was mad...|
|translate English to French: Awkward disaster mishmash has a team of scavengers coming across the...|
|translate English to French: Here is a fantastic concept for a film - a series of meteors crash i...|
|     translate English to French: I walked out of the cinema having suffered this film after 30 mins|
|translate English to French: A wildly uneven film where the major problem is the uneasy mix of co...|
|translate English to French: Leonard Rossiter and Frances de la Tour carry this film, not without...|
|                                                            translate En

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- `predict_batch_fn` uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays
- `predict_batch_udf` will convert the Spark DataFrame columns into numpy input batches for the predict function
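
Conceptually, the two roles described above can be sketched in plain Python. This is an illustration of the pattern only, not Spark's implementation: `make_predict_fn` and `run_in_batches` are hypothetical names, lists stand in for numpy arrays, and `str.upper` stands in for the model.

```python
from typing import Callable, Iterator, List


def make_predict_fn() -> Callable[[List[str]], List[str]]:
    """Stand-in for predict_batch_fn: 'load' once, then return a closure."""
    model = str.upper  # placeholder for an expensive model load

    def predict(batch: List[str]) -> List[str]:
        return [model(s) for s in batch]

    return predict


def run_in_batches(rows: List[str], batch_size: int) -> Iterator[List[str]]:
    """Feed rows to the predict function in chunks of at most batch_size."""
    predict = make_predict_fn()  # called once, reused for every batch
    for start in range(0, len(rows), batch_size):
        yield predict(rows[start:start + batch_size])


rows = [f"row {i}" for i in range(70)]
print([len(b) for b in run_in_batches(rows, batch_size=32)])  # [32, 32, 6]
```

The point of the closure is that the expensive load happens once, while the returned function is invoked for every batch.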

In [17]:
def predict_batch_fn():
    import numpy as np
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from pyspark import TaskContext

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Initializing model on worker {TaskContext.get().partitionId()}, device {device}.")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        flattened = np.squeeze(inputs).tolist()
        inputs = tokenizer(flattened, 
                           padding=True,
                           return_tensors="pt").to(device)
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_length=128)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
        print("predict: {}".format(len(flattened)))
        return string_outputs
    
    return predict

In [18]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=32)

In [19]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 10.2 ms, sys: 5.05 ms, total: 15.2 ms
Wall time: 7.41 s

In [20]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()



CPU times: user 3.93 ms, sys: 1.98 ms, total: 5.91 ms
Wall time: 4.08 s

In [21]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()



CPU times: user 3.85 ms, sys: 1.75 ms, total: 5.6 ms
Wall time: 4.08 s

In [22]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|
|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|
|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|
|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|
|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|
|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|
|          translate English to French: A good cast|            

Let's try English to German:

In [23]:
input_df2 = df.select(preprocess(col("text"), "translate English to German: ").alias("input")).cache()

In [24]:
%%time
# first pass caches model/fn
preds = input_df2.withColumn("preds", generate(struct("input")))
result = preds.collect()



CPU times: user 6.02 ms, sys: 705 μs, total: 6.73 ms
Wall time: 4.24 s

In [25]:
%%time
preds = input_df2.withColumn("preds", generate("input"))
result = preds.collect()



CPU times: user 6.12 ms, sys: 319 μs, total: 6.43 ms
Wall time: 3.88 s

In [26]:
%%time
preds = input_df2.withColumn("preds", generate(col("input")))
result = preds.collect()



CPU times: user 7.03 ms, sys: 16 μs, total: 7.05 ms
Wall time: 3.9 s

In [27]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to German: The only reason I'...|Der einzige Grund, warum ich sogar diesen Film ...|
|translate English to German: Awkward disaster m...|Awkward-Katastrophenmischmash hat ein Team von ...|
|translate English to German: Here is a fantasti...|Hier ist ein fantastisches Konzept für einen Fi...|
|translate English to German: I walked out of th...|Ich ging aus dem Kino, nachdem ich diesen Film ...|
|translate English to German: A wildly uneven fi...|Ein völlig ungleicher Film, in dem das Hauptpro...|
|translate English to German: Leonard Rossiter a...|Leonard Rossiter und Frances de la Tour tragen ...|
|          translate English to German: A good cast|            

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-server.png" alt="drawing" width="700"/>

In [28]:
from functools import partial

Import the helper class from server_utils.py:

In [29]:
sc.addPyFile("server_utils.py")

from server_utils import TritonServerManager

Define the Triton Server function:

In [30]:
def triton_server(ports):
    import time
    import signal
    import numpy as np
    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"SERVER: Using {DEVICE} device.")
    model = model.to(DEVICE)

    @batch
    def _infer_fn(**inputs):
        sentences = np.squeeze(inputs["text"]).tolist()
        print(f"SERVER: Received batch of size {len(sentences)}")
        decoded_sentences = [s.decode("utf-8") for s in sentences]
        inputs = tokenizer(decoded_sentences,
                        padding=True,
                        return_tensors="pt").to(DEVICE)
        output_ids = model.generate(input_ids=inputs["input_ids"],
                                    attention_mask=inputs["attention_mask"],
                                    max_length=128)
        outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])
        return {
            "translations": outputs,
        }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="ConditionalGeneration",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="translations", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()
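
The `ModelConfig` above allows the server to merge concurrent client requests into larger server-side batches. The sketch below is a deliberately simplified model of that grouping, not Triton's actual scheduler: `group_requests` is a hypothetical function, and it assumes a batch closes either when the delay window of its first request has passed or when adding another request would exceed `max_batch_size`.

```python
from typing import List, Tuple


def group_requests(arrivals: List[Tuple[int, int]],
                   max_batch_size: int = 64,
                   max_queue_delay_us: int = 5000) -> List[List[int]]:
    """Group (arrival_time_us, size) requests into server-side batches."""
    batches: List[List[int]] = []
    current: List[int] = []
    window_start = 0
    for t, size in sorted(arrivals):
        too_late = bool(current) and t - window_start > max_queue_delay_us
        too_big = bool(current) and sum(current) + size > max_batch_size
        if too_late or too_big:
            batches.append(current)  # close the open batch
            current = []
        if not current:
            window_start = t  # this request opens a new window
        current.append(size)
    if current:
        batches.append(current)
    return batches


# Three Spark batches of 32 rows arriving 1 ms apart, then one 20 ms later.
print(group_requests([(0, 32), (1000, 32), (2000, 32), (22000, 32)]))
# [[32, 32], [32], [32]]
```

Under this model, a server limit of 64 and a Spark `batch_size` of 32 mean at most two client batches are merged per model invocation.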

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP/gRPC/metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
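
As an illustration of the first step, here is one way free ports could be located using only the standard library. This is a sketch under assumptions, not the code in `server_utils.py`: `find_free_ports` is a hypothetical function, and the starting port of 7000 simply mirrors the ports that appear in the server output in this notebook.

```python
import socket
from typing import List


def find_free_ports(start_port: int = 7000, count: int = 3) -> List[int]:
    """Scan upward from start_port and return `count` ports that can be bound."""
    ports: List[int] = []
    port = start_port
    while len(ports) < count and port < 65536:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
                ports.append(port)
            except OSError:
                pass  # port in use or not permitted; try the next one
        port += 1
    if len(ports) < count:
        raise RuntimeError("not enough free ports found")
    return ports


http_port, grpc_port, metrics_port = find_free_ports()
print(http_port, grpc_port, metrics_port)
```

A probe like this is inherently racy, since another process can claim a port between the check and the server start, so a real manager would need to handle bind failures at launch as well.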

In [32]:
model_name = "ConditionalGeneration"
server_manager = TritonServerManager(model_name=model_name)

In [None]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server)

2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-07 11:03:44,810 - INFO - Starting 1 servers.

{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}

#### Define client function

Get the hostname -> url mapping from the server manager:

In [None]:
host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url
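
The mapping can be thought of as a projection of the dictionary returned by `start_servers`. The sketch below shows one plausible construction. `build_url_map` is a hypothetical function, and the `protocol://host:port` URL format is an assumption about what the server manager produces.

```python
from typing import Dict, List, Tuple

# Shape of the value returned by start_servers, per the output shown earlier.
ServerInfo = Dict[str, Tuple[int, List[int]]]


def build_url_map(servers: ServerInfo, protocol: str = "http") -> Dict[str, str]:
    """Map each hostname to a URL on its HTTP (index 0) or gRPC (index 1) port."""
    port_index = {"http": 0, "grpc": 1}[protocol]
    return {
        host: f"{protocol}://{host}:{ports[port_index]}"
        for host, (_pid, ports) in servers.items()
    }


servers = {"cb4ae00-lcedt": (2020631, [7000, 7001, 7002])}
print(build_url_map(servers))
# {'cb4ae00-lcedt': 'http://cb4ae00-lcedt:7000'}
```

Keying the map by hostname lets each Spark task look up the server running on its own node.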

Define the Triton inference function, which returns a predict function for batch inference through the server:

In [35]:
def triton_fn(model_name, host_to_url):
    import socket
    import numpy as np
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=240) as client:
            flattened = np.squeeze(inputs).tolist() 
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["translations"], -1)
            return result_data
        
    return infer_batch
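
The client encodes each string to UTF-8 bytes before sending, and the server function decodes the bytes before tokenizing. The round trip can be sketched without numpy or Triton, using nested lists in place of the `(batch, 1)` arrays. `encode_batch` and `decode_batch` are hypothetical names.

```python
from typing import List


def encode_batch(texts: List[str]) -> List[List[bytes]]:
    """Client side: one UTF-8 bytes object per row, shaped (batch, 1)."""
    return [[t.encode("utf-8")] for t in texts]


def decode_batch(rows: List[List[bytes]]) -> List[str]:
    """Server side: flatten the (batch, 1) rows and decode back to str."""
    return [row[0].decode("utf-8") for row in rows]


batch = ["Un film extrêmement inégal", "Voici un concept"]
wire = encode_batch(batch)
# Accented characters occupy more than one byte in UTF-8.
print(len(batch[0]), len(wire[0][0]))
print(decode_batch(wire) == batch)  # True
```

Using the same encoding on both sides is what keeps accented characters intact across the client/server boundary.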

In [39]:
generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=32)

#### Load and preprocess DataFrame

In [36]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [37]:
df = spark.read.parquet(data_path).limit(512).repartition(8)

In [38]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()

25/02/04 13:35:39 WARN CacheManager: Asked to cache already cached data.


#### Run Inference

In [40]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 5.09 ms, sys: 4.41 ms, total: 9.5 ms
Wall time: 4.96 s

In [41]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()



CPU times: user 5.4 ms, sys: 1.12 ms, total: 6.52 ms
Wall time: 4.41 s

In [42]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()



CPU times: user 4.59 ms, sys: 1.79 ms, total: 6.38 ms
Wall time: 4.55 s

In [43]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|
|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|
|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|
|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|
|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|
|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|
|          translate English to French: A good cast|            

#### Shut down servers on each executor

In [44]:
server_manager.stop_servers()

2025-02-04 13:35:53,794 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-04 13:35:58,983 - INFO - Sucessfully stopped 1 servers.                 


[True]

In [45]:
if not on_databricks:  # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()