# PySpark Huggingface Inferencing
### Text Classification using Pipelines with PyTorch

Based on: https://huggingface.co/docs/transformers/quicktour#pipeline-usage

In [31]:
import torch
from transformers import pipeline

In [32]:
# Pick the device index handed to the pipelines: GPU 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

In [33]:
# Build a sentiment-analysis pipeline with the task's default model on the chosen device.
classifier = pipeline(
    task="sentiment-analysis",
    device=device,
)

No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


In [34]:
# Classify one sentence; the pipeline returns a list holding a single label/score dict.
classifier(
    "We are very happy to show you the 🤗 Transformers library."
)

[{'label': 'POSITIVE', 'score': 0.9997795224189758}]

In [35]:
# Classify several sentences in one call and print one line per prediction.
sentences = [
    "We are very happy to show you the 🤗 Transformers library.",
    "We hope you don't hate it.",
]
results = classifier(sentences)
for result in results:
    print(f"label: {result['label']}, with score: {round(result['score'], 4)}")

label: POSITIVE, with score: 0.9998
label: NEGATIVE, with score: 0.5309


#### Use another model and tokenizer in the pipeline

In [36]:
# Hub identifier of the model and tokenizer loaded for the custom pipeline in this section.
model_name = "nlptown/bert-base-multilingual-uncased-sentiment"

In [37]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the tokenizer and the sequence-classification model published under model_name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [38]:
# Rebuild the pipeline around the explicitly loaded model/tokenizer and try a French sentence.
classifier = pipeline(
    "sentiment-analysis",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
classifier(
    "Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers."
)

[{'label': '5 stars', 'score': 0.7272651791572571}]

## Inference using Spark DL API

In [39]:
import pandas as pd
from pyspark.sql.functions import col, struct, pandas_udf
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import FloatType, StringType, StructField, StructType
from pyspark.sql import SparkSession

In [40]:
num_threads = 6

# Settings for a local Spark session used for demonstration; getOrCreate() below
# returns the existing session if one has already been created.
_config = {
    "spark.master": f"local[{num_threads}]",
    "spark.driver.host": "127.0.0.1",
    "spark.task.maxFailures": "1",
    "spark.driver.memory": "8g",
    "spark.executor.memory": "8g",
    "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
    "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.python.worker.reuse": "true",
}

# Apply every setting to the builder, then materialize the session.
builder = SparkSession.builder.appName("spark-dl-example")
for key, value in _config.items():
    builder = builder.config(key, value)
spark = builder.getOrCreate()

sc = spark.sparkContext

In [41]:
from datasets import load_dataset

# Load the test split of the IMDB dataset
data = load_dataset("imdb", split="test")

# One single-column row per review, holding the full review text. The cut to the
# first sentence happens later, in the first_sentence pandas UDF.
lines = [[example["text"]] for example in data]

# Spread the reviews over 10 partitions and keep the DataFrame cached for reuse.
df = spark.createDataFrame(lines, ['lines']).repartition(10).cache()

In [42]:
# Save the reviews as Parquet under "imdb_test", overwriting any earlier output.
(
    df.write
    .mode("overwrite")
    .parquet("imdb_test")
)

24/09/25 17:23:32 WARN TaskSetManager: Stage 0 contains a task of very large size (5123 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

In [43]:
# only use first sentence of IMDB reviews
@pandas_udf("string")
def first_sentence(text: pd.Series) -> pd.Series:
    """Return, for every review in ``text``, the part that precedes its first "."."""
    leading = []
    for review in text:
        # partition() yields the whole string as the head when no "." is present
        head, _, _ = review.partition(".")
        leading.append(head)
    return pd.Series(leading)

# Read the saved reviews back, keep only the derived "sentence" column, and cache 100 rows.
df = (
    spark.read.parquet("imdb_test")
    .withColumn("sentence", first_sentence(col("lines")))
    .select("sentence")
    .limit(100)
    .cache()
)
df.show(truncate=80)

+--------------------------------------------------------------------------------+
|                                                                        sentence|
+--------------------------------------------------------------------------------+
|In following Dylan Moran's star from the charming misanthrope bookstore owner...|
|Here in Australia Nights in Rodanthe is being promoted in the same class as t...|
|           The Tender Hook, or, Who Killed The Australian Film Industry? Case No|
|The only reason I'm even giving this movie a 4 is because it was made in to a...|
|Hooray for Title Misspellings! After reading reviews and contemplating, my gi...|
|                        This movie makes you wish imdb would let you vote a zero|
|To me this film is just a very very lame teen party movie with all the normal...|
|                                                                         I tried|
|Awkward disaster mishmash has a team of scavengers coming across the overturn...|
|Acc

In [44]:
def predict_batch_fn():
    """Load the default sentiment-analysis pipeline and return a batch prediction function.

    This is the factory handed to ``predict_batch_udf``; the returned ``predict``
    closes over the loaded pipeline.
    """
    import torch
    from transformers import pipeline

    # GPU 0 when CUDA is available, otherwise -1
    if torch.cuda.is_available():
        device = 0
    else:
        device = -1
    sentiment_pipe = pipeline("sentiment-analysis", device=device)

    def predict(inputs):
        """Classify a batch of sentences; ``inputs`` is converted with ``.tolist()`` first."""
        sentences = inputs.tolist()
        return sentiment_pipe(sentences)

    return predict

In [45]:
# Wrap the model factory as a Spark UDF returning a (label, score) struct,
# with rows handed to the model 10 at a time.
classify = predict_batch_udf(
    predict_batch_fn,
    return_type=(
        StructType()
        .add("label", StringType(), True)
        .add("score", FloatType(), True)
    ),
    batch_size=10,
)

In [46]:
%%time
# Input column passed wrapped in struct(); selecting "preds.*" expands the
# "struct" return_type into top-level columns.
preds = (
    df.withColumn("preds", classify(struct("sentence")))
    .select("sentence", "preds.*")
)
results = preds.collect()

No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.


CPU times: user 1.26 ms, sys: 9.57 ms, total: 10.8 ms
Wall time: 5.05 s


                                                                                

In [47]:
%%time
# Input column passed by name; selecting "preds.*" expands the
# "struct" return_type into top-level columns.
preds = (
    df.withColumn("preds", classify("sentence"))
    .select("sentence", "preds.*")
)
results = preds.collect()

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


CPU times: user 1.81 ms, sys: 4.6 ms, total: 6.42 ms
Wall time: 756 ms


                                                                                

In [48]:
%%time
# Input column passed as col(); selecting "preds.*" expands the
# "struct" return_type into top-level columns.
preds = (
    df.withColumn("preds", classify(col("sentence")))
    .select("sentence", "preds.*")
)
results = preds.collect()

CPU times: user 4.31 ms, sys: 0 ns, total: 4.31 ms
Wall time: 711 ms


                                                                                

In [49]:
# Display the predictions, with long cell values cut at 80 characters.
preds.show(truncate=80)

+--------------------------------------------------------------------------------+--------+----------+
|                                                                        sentence|   label|     score|
+--------------------------------------------------------------------------------+--------+----------+
|In following Dylan Moran's star from the charming misanthrope bookstore owner...|NEGATIVE|0.94646263|
|Here in Australia Nights in Rodanthe is being promoted in the same class as t...|NEGATIVE| 0.9811111|
|           The Tender Hook, or, Who Killed The Australian Film Industry? Case No|NEGATIVE|0.98526716|
|The only reason I'm even giving this movie a 4 is because it was made in to a...|NEGATIVE| 0.9995627|
|Hooray for Title Misspellings! After reading reviews and contemplating, my gi...|NEGATIVE|0.99946946|
|                        This movie makes you wish imdb would let you vote a zero|NEGATIVE| 0.9981305|
|To me this film is just a very very lame teen party movie with all the n

### Using Triton Inference Server

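Note: you can restart the kernel and run from this point to simulate running in a different node or environment. If you do, first re-run the cell above that creates the Spark session, because the cells below rely on the `spark` and `sc` variables (and on the `imdb_test` parquet files written earlier).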

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:
```
conda create -n huggingface -c conda-forge python=3.10.0
conda activate huggingface

export PYTHONNOUSERSITE=True
pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface.tar.gz
```

In [50]:
import numpy as np
import pandas as pd
import os
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct, pandas_udf
from pyspark.sql.types import FloatType, StringType, StructField, StructType

In [51]:
%%bash
# copy custom model to expected layout for Triton
# (the models directory is recreated from scratch, so nothing from an earlier run remains)
rm -rf models
mkdir -p models
cp -r models_config/hf_pipeline models

# add custom execution environment
# (huggingface.tar.gz is the conda-pack archive built with the commands listed above)
cp huggingface.tar.gz models

#### Start Triton Server on each executor

In [52]:
num_executors = 1
# Host directories that get mounted into the Triton container
triton_models_dir = f"{os.getcwd()}/models"
huggingface_cache_dir = f"{os.path.expanduser('~')}/.cache/huggingface"
# One partition per executor, so the start/stop functions run num_executors times
nodeRDD = sc.parallelize([*range(num_executors)], num_executors)

def start_triton(it):
    """Start a Triton Inference Server container on this node unless one already exists.

    Written for use with ``mapPartitions``; the partition iterator ``it`` is not read.
    The function closes over the notebook variables ``triton_models_dir`` and
    ``huggingface_cache_dir``, which are mounted into the container.

    Returns:
        ``[True]`` in every case. If the server does not report ready within the
        polling window, a warning line is printed instead of raising.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        # A matching container is already present: report it and start nothing new.
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container = docker_client.containers.run(
            "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            # host networking: the server is reached at localhost:8001 below
            network_mode="host",
            remove=True,
            shm_size="256M",
            volumes={
                triton_models_dir: {"bind": "/models", "mode": "ro"},
                huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))
        # wait for triton to be running
        time.sleep(15)

        triton_client = grpcclient.InferenceServerClient("localhost:8001")

        # Poll every 5 seconds until the server reports ready or the timeout elapses.
        elapsed = 0
        timeout = 120
        ready = False
        while not ready and elapsed < timeout:
            time.sleep(5)
            elapsed += 5
            try:
                ready = triton_client.is_server_ready()
            except Exception:
                # The readiness call can fail while the server is still starting; retry.
                pass

        if not ready:
            print(">>>> triton not ready after {} seconds of polling".format(timeout))

    return [True]

# Run start_triton in every partition of nodeRDD and gather the results on the driver.
(
    nodeRDD
    .barrier()
    .mapPartitions(start_triton)
    .collect()
)

>>>> starting triton: 051f90cdbb4f                                  (0 + 1) / 1]
                                                                                

[True]

#### Run inference

In [53]:
# only use first sentence of IMDB reviews
@pandas_udf("string")
def first_sentence(text: pd.Series) -> pd.Series:
    """Return, for every review in ``text``, the part that precedes its first "."."""
    leading = []
    for review in text:
        # partition() yields the whole string as the head when no "." is present
        head, _, _ = review.partition(".")
        leading.append(head)
    return pd.Series(leading)

# Read the saved reviews, keep only the derived "sentence" column, and take 1000 rows.
df = (
    spark.read.parquet("imdb_test")
    .withColumn("sentence", first_sentence(col("lines")))
    .select("sentence")
    .limit(1000)
)

In [54]:
def triton_fn(triton_uri, model_name):
    """Create a prediction function that sends batches to a Triton server over gRPC.

    Args:
        triton_uri: address passed to ``InferenceServerClient`` (e.g. "localhost:8001").
        model_name: name of the served model; its metadata is fetched once, here.

    Returns:
        ``predict(inputs)``, where ``inputs`` is either a single numpy array (bound
        to the model's first input) or a dict of numpy arrays keyed by input name.
        It returns a single numpy array when the model has one output, otherwise a
        dict of numpy arrays keyed by output name.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype that inputs are cast to before sending.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "UINT8": np.dtype(np.uint8),
        "UINT16": np.dtype(np.uint16),
        "UINT32": np.dtype(np.uint32),
        "UINT64": np.dtype(np.uint64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object),
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        """Run one inference request for ``inputs`` and return the model outputs."""
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            first_input = model_meta.inputs[0]
            request = [grpcclient.InferInput(first_input.name, inputs.shape, first_input.datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[first_input.datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [
                grpcclient.InferInput(meta.name, inputs[meta.name].shape, meta.datatype)
                for meta in model_meta.inputs
            ]
            for infer_input in request:
                # InferInput exposes name/datatype as methods, unlike the metadata entries
                infer_input.set_data_from_numpy(
                    inputs[infer_input.name()].astype(np_types[infer_input.datatype()])
                )

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [55]:
from functools import partial

# Wrap the Triton client factory as a Spark UDF returning a (label, score) struct,
# with rows sent to the server 100 at a time.
classify = predict_batch_udf(
    partial(
        triton_fn,
        triton_uri="localhost:8001",
        model_name="hf_pipeline",
    ),
    return_type=StructType(
        [
            StructField("label", StringType(), True),
            StructField("score", FloatType(), True),
        ]
    ),
    input_tensor_shapes=[[1]],
    batch_size=100,
)

In [56]:
%%time
# first pass caches model/fn
# Input column passed wrapped in struct(); selecting "preds.*" expands the
# "struct" return_type into top-level columns.
preds = (
    df.withColumn("preds", classify(struct("sentence")))
    .select("sentence", "preds.*")
)
results = preds.collect()

[Stage 20:>                                                         (0 + 1) / 1]

CPU times: user 7.08 ms, sys: 5.09 ms, total: 12.2 ms
Wall time: 6.01 s


                                                                                

In [57]:
%%time
# Input column passed by name; selecting "preds.*" expands the
# "struct" return_type into top-level columns.
preds = (
    df.withColumn("preds", classify("sentence"))
    .select("sentence", "preds.*")
)
results = preds.collect()

[Stage 21:>                                                         (0 + 1) / 1]

CPU times: user 4.46 ms, sys: 6.75 ms, total: 11.2 ms
Wall time: 5.78 s


                                                                                

In [58]:
%%time
# Input column passed as col(); selecting "preds.*" expands the
# "struct" return_type into top-level columns.
preds = (
    df.withColumn("preds", classify(col("sentence")))
    .select("sentence", "preds.*")
)
results = preds.collect()

[Stage 22:>                                                         (0 + 1) / 1]

CPU times: user 8.91 ms, sys: 1.55 ms, total: 10.5 ms
Wall time: 5.73 s


                                                                                

In [59]:
# Display the predictions without truncating cell values.
preds.show(truncate=False)

[Stage 23:>                                                         (0 + 1) / 1]

+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+
|sentence                                                                                                                                                                                                                                                                                                |label   |score     |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+
|In following Dylan Moran's star from the c

                                                                                

#### Stop Triton Server on each executor

In [60]:
def stop_triton(it):
    """Stop every container on this node whose name matches "spark-triton".

    Written for use with ``mapPartitions``; the partition iterator ``it`` is not read.
    Every container listed in the log line is stopped, so the message and the
    action agree even when more than one container matches the filter.

    Returns:
        ``[True]``, whether or not any container was found.
    """
    import docker

    client = docker.from_env()
    containers = client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    for container in containers:
        # allow up to 120 seconds for the container to stop
        container.stop(timeout=120)

    return [True]

# Run stop_triton in every partition of nodeRDD and gather the results on the driver.
(
    nodeRDD
    .barrier()
    .mapPartitions(stop_triton)
    .collect()
)

>>>> stopping containers: ['051f90cdbb4f']
                                                                                

[True]

In [61]:
# Shut down the Spark session created earlier in the notebook.
spark.stop()