<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark Huggingface Inferencing
### Sentiment Analysis using Pipelines with PyTorch

In this notebook, we demonstrate distributed inference with Hugging Face Pipelines to perform sentiment analysis.  
Adapted from: https://huggingface.co/docs/transformers/quicktour#pipeline-usage

In [1]:
import torch
from transformers import pipeline

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
classifier = pipeline("sentiment-analysis", device=device)

No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


In [4]:
classifier("We are very happy to show you the 🤗 Transformers library.")

[{'label': 'POSITIVE', 'score': 0.9997795224189758}]

In [5]:
results = classifier(["We are very happy to show you the 🤗 Transformers library.", "We hope you don't hate it."])
for result in results:
    print(f"label: {result['label']}, with score: {round(result['score'], 4)}")

label: POSITIVE, with score: 0.9998
label: NEGATIVE, with score: 0.5309


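Next, let's load a different model and tokenizer explicitly and pass them to the pipeline. This model is multilingual and predicts a sentiment rating from 1 to 5 stars.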

In [6]:
model_name = "nlptown/bert-base-multilingual-uncased-sentiment"

In [7]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [8]:
classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=device)
classifier("Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.")

[{'label': '5 stars', 'score': 0.7272652983665466}]

## PySpark

In [9]:
from pyspark.sql.functions import col, struct, pandas_udf
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [10]:
import os
import json
import pandas as pd
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

Check the cluster environment to handle any platform-specific Spark configurations.

In [11]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For cloud service provider (CSP) environments, Spark is either preconfigured (Databricks) or we need to create the Spark Session ourselves (Dataproc).

In [12]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
        # Point PyTriton to correct libpython3.11.so:
        conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH")
    elif on_dataproc:
        # Point PyTriton to correct libpython3.11.so:
        conda_lib_path = "/opt/conda/miniconda3/lib"
        conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_lib_path}:$LD_LIBRARY_PATH")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/01/06 18:29:40 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/01/06 18:29:40 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/06 18:29:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
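With `spark.executor.cores=8`, `spark.executor.resource.gpu.amount=1` and `spark.task.resource.gpu.amount=0.125`, up to eight tasks can share one GPU on each executor. The sketch below is a simplified model of that arithmetic, not Spark's scheduler code. It assumes the default `spark.task.cpus=1`, and the helper name `concurrent_tasks_per_executor` is hypothetical.

```python
import math


def concurrent_tasks_per_executor(executor_cores: int, task_cpus: int,
                                  executor_gpus: float, task_gpus: float) -> int:
    """Tasks one executor can run at once, limited by both CPU and GPU slots.

    Simplified model: each resource yields (available / per-task amount) slots,
    and the scarcest resource wins.
    """
    cpu_slots = executor_cores // task_cpus
    # GPU amounts can be fractional; the epsilon guards against float rounding.
    gpu_slots = math.floor(executor_gpus / task_gpus + 1e-9)
    return min(cpu_slots, gpu_slots)


# Values from the session config above (spark.task.cpus assumed to be 1).
print(concurrent_tasks_per_executor(8, 1, 1, 0.125))  # 8
```

Raising `spark.task.resource.gpu.amount` to `1.0` would drop this to a single task per executor, which is the lever the Triton section later uses for its server tasks.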


In [13]:
dataset = load_dataset("imdb", split="test")
dataset = dataset.to_pandas().drop(columns="label")

#### Create PySpark DataFrame

In [14]:
df = spark.createDataFrame(dataset).repartition(8)
df.schema

StructType([StructField('text', StringType(), True)])

In [15]:
df.count()

25000

In [16]:
df.take(1)

25/01/06 18:29:47 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.


[Row(text="Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.")]

In [17]:
data_path = "spark-dl-datasets/imdb_test"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

25/01/06 18:29:47 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.


#### Load and preprocess DataFrame

Define the preprocessing function. We take the first sentence of each review (approximated as the text before the first period) as the input for sentiment analysis.

In [18]:
@pandas_udf("string")
def preprocess(text: pd.Series) -> pd.Series:
    return pd.Series([s.split(".")[0] for s in text])
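Splitting on `.` is only a heuristic for "first sentence". A plain-Python sketch of the same per-row logic as the UDF above (the name `first_sentence` is hypothetical) makes its edge cases visible, for example abbreviations that cut the text short:

```python
def first_sentence(text: str) -> str:
    """Same rule as the preprocess UDF: keep everything before the first '.'."""
    return text.split(".")[0]


samples = [
    "This movie has been done before. The plot is predictable.",
    "No period at all",
    "Mr. Smith was great. Really.",
]
for s in samples:
    print(repr(first_sentence(s)))
# 'This movie has been done before'
# 'No period at all'
# 'Mr'
```

For this demo the heuristic is good enough; a sentence tokenizer would handle abbreviations and the `<br />` markup in the IMDB reviews more robustly.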

In [19]:
df = spark.read.parquet(data_path).limit(256).repartition(8)
df = df.select(preprocess(col("text")).alias("input")).cache()
df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|
|                          There were two things I hated about WASTED : The directing and the script |
|                                I'm rather surprised that anybody found this film touching or moving|
|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|
|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|
|                                                                     This movie has been done before|
[ as a new resolution for this year 2005, i decide to write a comment fo

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses the Hugging Face pipeline API (backed by PyTorch) to load the model and returns a predict function that operates on numpy arrays.
- predict_batch_udf converts the Spark DataFrame columns into numpy input batches for the predict function.
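The division of labor above can be pictured with a small pure-Python model of how one partition's rows reach the predict function. This is a simplified sketch, not PySpark's implementation: it assumes rows are chunked into consecutive batches of at most `batch_size` and the results are concatenated in order. `run_in_batches` and `fake_predict` are hypothetical names.

```python
from typing import Callable, List


def run_in_batches(rows: List[str], batch_size: int,
                   predict: Callable[[List[str]], list]) -> list:
    """Feed rows to predict() in consecutive chunks of at most batch_size."""
    results = []
    for start in range(0, len(rows), batch_size):
        results.extend(predict(rows[start:start + batch_size]))
    return results


batch_sizes = []


def fake_predict(batch):
    """Stand-in for the pipeline: records the batch size, returns dummy labels."""
    batch_sizes.append(len(batch))
    return [{"label": "POSITIVE", "score": 1.0} for _ in batch]


# A 70-row partition with batch_size=32 needs three calls.
out = run_in_batches([f"review {i}" for i in range(70)], 32, fake_predict)
print(batch_sizes)  # [32, 32, 6]
```

In this notebook, 256 rows are repartitioned into 8 partitions (roughly 32 rows each), so with `batch_size=32` each partition should need about one predict call.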

In [20]:
def predict_batch_fn():
    import torch
    from transformers import pipeline
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = pipeline("sentiment-analysis", device=device)
    def predict(inputs):
        return pipe(inputs.tolist())
    return predict

In [21]:
classify = predict_batch_udf(predict_batch_fn,
                             return_type=StructType([
                                 StructField("label", StringType(), True),
                                 StructField("score", FloatType(), True)
                             ]),
                             batch_size=32)
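The pipeline returns one dict per input with `label` and `score` keys, which match the `StructType` field names, and `predict` above passes that row-oriented list straight through. If you would rather return one column per field, the reshaping is shown in the plain-Python sketch below. `rows_to_columns` is a hypothetical helper, and the exact container types `predict_batch_udf` accepts for struct outputs should be checked against its documentation.

```python
def rows_to_columns(rows, fields):
    """Convert a row-oriented list of dicts into a column-oriented dict of lists.

    Raises KeyError if a row is missing one of the expected struct fields.
    """
    return {f: [row[f] for row in rows] for f in fields}


pipeline_output = [
    {"label": "POSITIVE", "score": 0.9998},
    {"label": "NEGATIVE", "score": 0.5309},
]
print(rows_to_columns(pipeline_output, ["label", "score"]))
# {'label': ['POSITIVE', 'NEGATIVE'], 'score': [0.9998, 0.5309]}
```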

In [22]:
%%time
# first pass caches model/fn
# note: expanding the "struct" return_type to top-level columns
preds = df.withColumn("preds", classify(struct("input"))).select("input", "preds.*")
results = preds.collect()



CPU times: user 14.8 ms, sys: 4.23 ms, total: 19 ms
Wall time: 3.15 s

In [23]:
%%time
preds = df.withColumn("preds", classify("input")).select("input", "preds.*")
results = preds.collect()

CPU times: user 2.59 ms, sys: 2.33 ms, total: 4.91 ms
Wall time: 393 ms
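The first pass (3.15 s wall time) is much slower than the repeated passes (about 0.4 s). As the cell comment notes, the first pass caches the model and predict function, and the session enables `spark.python.worker.reuse`, so later passes reuse the loaded pipeline. The effect resembles memoizing an expensive factory once per process, as in this stdlib sketch. `load_model` is a hypothetical stand-in for `predict_batch_fn`, not PySpark's caching code.

```python
import functools

load_count = 0


@functools.lru_cache(maxsize=None)
def load_model(name: str):
    """Stand-in for an expensive model load; runs once per process per name."""
    global load_count
    load_count += 1
    return lambda texts: [len(t) for t in texts]  # toy "model"


for _ in range(3):  # three "passes" over the data in the same worker process
    predict = load_model("sentiment")
    predict(["a", "bb"])
print(load_count)  # 1
```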


In [24]:
%%time
preds = df.withColumn("preds", classify(col("input"))).select("input", "preds.*")
results = preds.collect()

CPU times: user 2.65 ms, sys: 2.41 ms, total: 5.06 ms
Wall time: 398 ms


In [25]:
preds.show(truncate=80)

+--------------------------------------------------------------------------------+--------+----------+
|                                                                           input|   label|     score|
+--------------------------------------------------------------------------------+--------+----------+
|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984042|
|      There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979019|
|            I'm rather surprised that anybody found this film touching or moving|POSITIVE| 0.8392794|
|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99726933|
|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE|0.98212516|
|                                                 This movie has been done before|NEGATIVE|0.94194806|
|[ as a new resolution for this year 2005, i decide to write a comment fo

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.
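The start/stop lifecycle in the steps above can be illustrated with the standard library alone. This is a toy sketch, not PyTriton: `server`, `start_server` and `stop_server` are hypothetical stand-ins, and it assumes a POSIX system because it uses the `fork` start method and `SIGTERM`.

```python
import multiprocessing as mp
import os
import signal
import time


def server(ready):
    """Toy server: installs a SIGTERM handler, then 'serves' until signaled."""
    stop = []  # the handler only appends here, so it never takes a lock
    signal.signal(signal.SIGTERM, lambda signum, frame: stop.append(signum))
    ready.set()
    while not stop:
        time.sleep(0.05)


def start_server():
    ctx = mp.get_context("fork")  # POSIX-only, like the os.kill-based shutdown
    ready = ctx.Event()
    proc = ctx.Process(target=server, args=(ready,))
    proc.start()
    if not ready.wait(timeout=10):
        raise RuntimeError("server did not become ready")
    return proc


def stop_server(proc, attempts=5, delay=1.0):
    """Send SIGTERM until the process has exited; return True on success."""
    for _ in range(attempts):
        os.kill(proc.pid, signal.SIGTERM)
        proc.join(timeout=delay)
        if not proc.is_alive():
            return True
    return False


proc = start_server()
print(stop_server(proc), proc.exitcode)  # True 0
```

Unlike this sketch, the notebook's shutdown runs in a different Spark task from the one that started the server, so it cannot `join()` the process. Instead it keeps sending SIGTERM until `os.kill` raises `OSError`.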

<img src="../images/spark-pytriton.png" alt="drawing" width="700"/>

In [26]:
from functools import partial

In [27]:
def triton_server(ports):
    import json
    import time
    import signal
    import numpy as np
    import torch
    from transformers import pipeline
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = pipeline("sentiment-analysis", device=device)
    print(f"SERVER: Using {device} device.")

    @batch
    def _infer_fn(**inputs):
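        # inputs["text"] has shape (batch, 1); squeeze only the last axis so a
        # batch of one still yields a list rather than a single bytes object.
        sentences = np.squeeze(inputs["text"], -1).tolist()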
        print(f"SERVER: Received batch of size {len(sentences)}")
        decoded_sentences = [s.decode("utf-8") for s in sentences]
        return {
            "outputs": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])
        }

    workspace_path = f"triton_workspace_{time.strftime('%Y%m%d_%H%M%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="SentimentAnalysis",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="outputs", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

def start_triton(ports, model_name):
    import socket
    from pathlib import Path
    from multiprocessing import Process
    from pytriton.client import ModelClient

    os.chdir(Path.home())

    hostname = socket.gethostname()
    process = Process(target=triton_server, args=(ports,))
    process.start()

    client = ModelClient(f"http://localhost:{ports[0]}", model_name)
    patience = 8
    while patience > 0:
        try:
            client.wait_for_server(5)
            return [(hostname, process.pid)]
        except Exception:
            print("Waiting for server to be ready...")
            patience -= 1

    emsg = "Failure: client waited too long for server startup."
    print(emsg)
    return [(hostname, emsg)]
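The `patience` loop in `start_triton` bounds how long a task waits for its server. Each `wait_for_server(5)` call can block for up to 5 s, so the worst case is roughly 8 × 5 s = 40 s before the task reports failure. A generic sketch of the same bounded-retry pattern, where `wait_until_ready` and `flaky_check` are hypothetical names and the check is any callable that raises until the server is up:

```python
from typing import Callable


def wait_until_ready(check: Callable[[float], None], patience: int = 8,
                     timeout_s: float = 5.0) -> bool:
    """Call check(timeout_s) up to `patience` times; True once it stops raising."""
    for _ in range(patience):
        try:
            check(timeout_s)
            return True
        except Exception:
            print("Waiting for server to be ready...")
    return False


attempts = []


def flaky_check(timeout_s):
    """Fake readiness probe that fails twice before succeeding."""
    attempts.append(timeout_s)
    if len(attempts) < 3:
        raise TimeoutError("not ready yet")


print(wait_until_ready(flaky_check, patience=8, timeout_s=0.0))  # True
print(len(attempts))  # 3
```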

#### Start Triton servers

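To ensure that only one Triton inference server is started per node, we use stage-level scheduling: each server-launch task requests a full GPU (and, without the RAPIDS SQL plugin, more than half of the executor's cores), so no two such tasks can be scheduled on the same executor.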

In [28]:
def _use_stage_level_scheduling(spark, rdd):

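    # Compare numerically: as strings, "3.10.0" < "3.4.0" would be True.
    major, minor = (int(v) for v in spark.version.split(".")[:2])
    if (major, minor) < (3, 4):
        raise Exception("Stage-level scheduling is not supported in Spark < 3.4.0")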

    executor_cores = spark.conf.get("spark.executor.cores")
    assert executor_cores is not None, "spark.executor.cores is not set"
    executor_gpus = spark.conf.get("spark.executor.resource.gpu.amount")
    assert executor_gpus is not None and int(executor_gpus) <= 1, "spark.executor.resource.gpu.amount must be set and <= 1"

    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    spark_plugins = spark.conf.get("spark.plugins", " ")
    assert spark_plugins is not None
    spark_rapids_sql_enabled = spark.conf.get("spark.rapids.sql.enabled", "true")
    assert spark_rapids_sql_enabled is not None

    task_cores = (
        int(executor_cores)
        if "com.nvidia.spark.SQLPlugin" in spark_plugins
        and "true" == spark_rapids_sql_enabled.lower()
        else (int(executor_cores) // 2) + 1
    )

    task_gpus = 1.0
    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    rp = ResourceProfileBuilder().require(treqs).build
    print(f"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})")

    return rdd.withResources(rp)
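Two details of the helper above are worth making explicit. First, comparing version strings directly is lexicographic, so `"3.10.0" < "3.4.0"` is true. Second, without the RAPIDS SQL plugin each task asks for more than half of the executor's cores, which guarantees that at most one such task lands on each executor. A plain-Python sketch of both, with hypothetical helper names:

```python
def parse_major_minor(version: str) -> tuple:
    """'3.10.1' -> (3, 10); ignores the patch level and suffixes like '-SNAPSHOT'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def task_cores(executor_cores: int, rapids_sql_enabled: bool) -> int:
    """Cores requested per server task, following _use_stage_level_scheduling."""
    if rapids_sql_enabled:
        return executor_cores
    # More than half the executor's cores: two such tasks cannot fit on one executor.
    return executor_cores // 2 + 1


print("3.10.0" < "3.4.0")                       # True: lexicographic string comparison
print(parse_major_minor("3.10.0") < (3, 4))     # False: numeric tuple comparison
print(task_cores(8, rapids_sql_enabled=False))  # 5, matching the cores=5 printed later
```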

**Specify the number of nodes in the cluster.**  
Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. 

In [29]:
# Change based on cluster setup
num_nodes = 1 if on_standalone else 4

In [None]:
sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)

Requesting stage-level resources: (cores=5, gpu=1.0)


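Each Triton server needs three ports: one for HTTP requests, one for gRPC requests, and one for the metrics service. The helper below scans upward from port 7000 for ports with no existing connection. Because it runs on the driver, it assumes the same ports are also free on each worker.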

In [31]:
def find_ports():
    import psutil

    ports = []
    # Local ports that already have an open inet connection on this host.
    conns = {conn.laddr.port for conn in psutil.net_connections(kind="inet")}
    i = 7000
    while len(ports) < 3:
        if i not in conns:
            ports.append(i)
        i += 1

    return ports
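`find_ports` relies on psutil's view of existing connections. A stdlib-only alternative, offered as a sketch rather than a drop-in replacement, checks availability by trying to bind each port. `find_free_ports` is a hypothetical name. Like the original, it only checks the host it runs on, and a port that binds now may still be taken before Triton starts.

```python
import socket


def find_free_ports(count: int = 3, start: int = 7000, end: int = 7100) -> list:
    """Return `count` ports in [start, end) that can currently be bound on this host."""
    ports = []
    for port in range(start, end):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("", port))  # all interfaces, as a server would bind
            except OSError:
                continue  # port in use (or not permitted); try the next one
        ports.append(port)
        if len(ports) == count:
            return ports
    raise RuntimeError(f"fewer than {count} free ports in [{start}, {end})")


print(find_free_ports())
```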

In [32]:
model_name = "SentimentAnalysis"

ports = find_ports()
assert len(ports) == 3
print(f"Using ports {ports}")

Using ports [7000, 7001, 7002]


In [33]:
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()
print("Triton Server PIDs:\n", json.dumps(pids, indent=4))

Triton Server PIDs:
 {
    "cb4ae00-lcedt": 2571652
}

#### Define client function

In [34]:
url = f"http://localhost:{ports[0]}"

In [35]:
def triton_fn(url, model_name, init_timeout_s):
    import json
    import numpy as np
    from pytriton.client import ModelClient

    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
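            # inputs has shape (batch, 1) because of input_tensor_shapes=[[1]];
            # squeeze only the last axis so a batch of one still yields a list.
            flattened = np.squeeze(inputs, -1).tolist()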
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["outputs"], -1)
            return [json.loads(o) for o in result_data]
        
    return infer_batch
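Strings cross the Triton boundary as UTF-8 bytes, and the pipeline's dict outputs travel back as JSON strings, one per row. The same contract can be shown without Triton or NumPy: in this sketch, nested lists stand in for the `np.bytes_`/object arrays, and all function names plus the toy classifier are hypothetical.

```python
import json


def encode_batch(texts):
    """Client side: one UTF-8 bytes value per row, shaped [[b'...'], ...]."""
    return [[t.encode("utf-8")] for t in texts]


def serve_batch(rows, classify):
    """Server side: decode each row, run the model, return one JSON string per row."""
    sentences = [row[0].decode("utf-8") for row in rows]
    return [[json.dumps(pred)] for pred in classify(sentences)]


def decode_outputs(rows):
    """Client side: parse each JSON string back into a dict."""
    return [json.loads(row[0]) for row in rows]


def toy(sentences):
    """Toy classifier standing in for the Hugging Face pipeline."""
    return [{"label": "POSITIVE" if "happy" in s else "NEGATIVE", "score": 1.0}
            for s in sentences]


texts = ["We are very happy 🤗", "We hope you don't hate it."]
print(decode_outputs(serve_batch(encode_batch(texts), toy)))
# [{'label': 'POSITIVE', 'score': 1.0}, {'label': 'NEGATIVE', 'score': 1.0}]
```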

#### Load and preprocess DataFrame

In [36]:
@pandas_udf("string")
def preprocess(text: pd.Series) -> pd.Series:
    return pd.Series([s.split(".")[0] for s in text])

In [37]:
df = spark.read.parquet(data_path).limit(256).repartition(8)
df = df.select(preprocess(col("text")).alias("input")).cache()

25/01/06 18:29:56 WARN CacheManager: Asked to cache already cached data.


#### Run Inference

In [38]:
classify = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=600),
                             return_type=StructType([
                                 StructField("label", StringType(), True),
                                 StructField("score", FloatType(), True)
                             ]),
                             input_tensor_shapes=[[1]],
                             batch_size=32)
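On the client side, predict_batch_udf sends requests of up to 32 rows (`batch_size=32`). The server, in turn, was configured with `max_batch_size=64` and a 5 ms queue delay, so Triton's dynamic batcher can merge concurrent requests from different Spark tasks. Below is a deliberately simplified model of that policy. The real batcher has more options, such as preferred batch sizes, and `form_batches` is a hypothetical name.

```python
def form_batches(arrivals, max_batch_size=64, max_delay_s=0.005):
    """Group client requests into server batches under a simplified rule.

    arrivals: (arrival_time_s, num_rows) pairs sorted by time. A batch opens at
    its first request and closes when the next request would exceed
    max_batch_size or arrives after the delay window. Returns rows per batch.
    """
    batches = []
    current, opened_at = 0, None
    for t, rows in arrivals:
        if current and (current + rows > max_batch_size or t - opened_at > max_delay_s):
            batches.append(current)
            current, opened_at = 0, None
        if opened_at is None:
            opened_at = t
        current += rows
    if current:
        batches.append(current)
    return batches


# Two 32-row requests 1 ms apart merge into one 64-row batch; a third request
# arriving 20 ms later is served on its own.
print(form_batches([(0.000, 32), (0.001, 32), (0.021, 32)]))  # [64, 32]
```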

In [39]:
%%time
# first pass caches model/fn
# note: expanding the "struct" return_type to top-level columns
preds = df.withColumn("preds", classify(struct("input"))).select("input", "preds.*")
results = preds.collect()



CPU times: user 16.7 ms, sys: 3.77 ms, total: 20.4 ms
Wall time: 2.58 s

In [40]:
%%time
preds = df.withColumn("preds", classify("input")).select("input", "preds.*")
results = preds.collect()

CPU times: user 2.77 ms, sys: 0 ns, total: 2.77 ms
Wall time: 462 ms


In [41]:
%%time
preds = df.withColumn("preds", classify(col("input"))).select("input", "preds.*")
results = preds.collect()

CPU times: user 2.51 ms, sys: 2.71 ms, total: 5.22 ms
Wall time: 461 ms


In [42]:
preds.show(truncate=70)

+----------------------------------------------------------------------+--------+----------+
|                                                                 input|   label|     score|
+----------------------------------------------------------------------+--------+----------+
|Doesn't anyone bother to check where this kind of sludge comes from...|NEGATIVE| 0.9984042|
|There were two things I hated about WASTED : The directing and the ...|NEGATIVE| 0.9979019|
|  I'm rather surprised that anybody found this film touching or moving|POSITIVE| 0.8392794|
|Cultural Vandalism Is the new Hallmark production of Gulliver's Tra...|NEGATIVE|0.99726933|
|I was at Wrestlemania VI in Toronto as a 10 year old, and the event...|POSITIVE|0.98212516|
|                                       This movie has been done before|NEGATIVE|0.94194806|
|[ as a new resolution for this year 2005, i decide to write a comme...|NEGATIVE|0.99678314|
|This movie is over hyped!! I am sad to say that I manage to watch t..

#### Shut down server on each executor

In [43]:
def stop_triton(pids):
    import os
    import socket
    import signal
    import time 
    
    hostname = socket.gethostname()
    pid = pids.get(hostname, None)
    assert pid is not None, f"Could not find pid for {hostname}"
    
    for _ in range(5):
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError:
            return [True]
        time.sleep(5)

    return [False]

shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)
shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()

Requesting stage-level resources: (cores=5, gpu=1.0)

[True]

In [44]:
spark.stop()