# PySpark Huggingface Inferencing
## Sentence Transformers with PyTorch

From: https://huggingface.co/sentence-transformers

In [2]:
from sentence_transformers import SentenceTransformer
# Instantiate the sentence-embedding model identified by 'paraphrase-MiniLM-L6-v2'.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Sentences we want to encode, given as a list of strings (a single example here).
sentence = ['This framework generates embeddings for each input sentence']


# Sentences are encoded by calling model.encode(); the result is kept in
# `embedding` so the next cell can display it.
embedding = model.encode(sentence)

  from tqdm.autonotebook import tqdm, trange


In [3]:
embedding

array([[-1.76214486e-01,  1.20601125e-01, -2.93624103e-01,
        -2.29858145e-01, -8.22926834e-02,  2.37709388e-01,
         3.39985073e-01, -7.80964196e-01,  1.18127652e-01,
         1.63373873e-01, -1.37715250e-01,  2.40282789e-01,
         4.25125629e-01,  1.72417969e-01,  1.05279535e-01,
         5.18164277e-01,  6.22218847e-02,  3.99285913e-01,
        -1.81652412e-01, -5.85578680e-01,  4.49718609e-02,
        -1.72750548e-01, -2.68443346e-01, -1.47386059e-01,
        -1.89217880e-01,  1.92150414e-01, -3.83842528e-01,
        -3.96007061e-01,  4.30648923e-01, -3.15319657e-01,
         3.65950078e-01,  6.05156757e-02,  3.57325763e-01,
         1.59736335e-01, -3.00983846e-01,  2.63250172e-01,
        -3.94311100e-01,  1.84855402e-01, -3.99549127e-01,
        -2.67889708e-01, -5.45117438e-01, -3.13405506e-02,
        -4.30644155e-01,  1.33278280e-01, -1.74793780e-01,
        -4.35465664e-01, -4.77378994e-01,  7.12554976e-02,
        -7.37000555e-02,  5.69137216e-01, -2.82579571e-0

## PySpark

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [4]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql import SparkSession
from datasets import load_dataset

In [5]:
num_threads = 6

# Settings for a local demonstration session; each entry is applied to the
# session builder unchanged. getOrCreate() reuses a session if one already exists.
_config = {
    "spark.master": f"local[{num_threads}]",
    "spark.driver.host": "127.0.0.1",
    "spark.task.maxFailures": "1",
    "spark.driver.memory": "8g",
    "spark.executor.memory": "8g",
    "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
    "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.python.worker.reuse": "true",
}

# Keep the builder under its own name so `spark` only ever refers to the session.
builder = SparkSession.builder.appName("spark-dl-example")
for key, value in _config.items():
    builder = builder.config(key, value)
spark = builder.getOrCreate()

sc = spark.sparkContext

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
24/09/25 17:05:29 WARN Utils: Your hostname, dgx2h0194.spark.sjc4.nvmetal.net resolves to a loopback address: 127.0.1.1; using 10.150.30.2 instead (on interface enp134s0f0np0)
24/09/25 17:05:29 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/25 17:05:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [6]:
# Load the IMDB reviews (test split) and keep, for each review, only the text
# that precedes its first "." as a single-column row.
data = load_dataset("imdb", split="test")
lines = [[example["text"].split(".")[0]] for example in data]

# Build a one-column DataFrame ("lines"), repartition it into 10 partitions and
# write it to parquet, overwriting any previous "imdb_test" output.
df = spark.createDataFrame(lines, ['lines']).repartition(10)
df.write.mode("overwrite").parquet("imdb_test")

                                                                                

In [7]:
# Only use the first 100 rows of the parquet data, since inference is slow;
# the limited DataFrame is cached because the following cells reuse it.
df = spark.read.parquet("imdb_test").limit(100).cache()

In [8]:
df.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|A ridiculous movie, a terrible editing job, worst screenplay, ridiculous acting, a story that is completely ununderst...|
|                                                        Most of this film was okay, for a sequel of a sequel of a sequel|
|                                                                                                                 I tried|
|                                             This movie attempted to make Stu Ungar's life interesting by being creative|
|After I saw this I concluded that it was most likely a chick flick; afterward I found out that Keira's mother wrote t...|
|Jeff Speakman n

In [9]:
def predict_batch_fn():
    """Create the batch prediction function handed to ``predict_batch_udf``.

    The sentence-transformers import and the model construction happen inside
    this factory, so they run wherever the factory itself is invoked.

    Returns:
        A callable that takes one batch (an object exposing ``tolist()``),
        converts it to a Python list and returns ``model.encode`` of that list.
    """
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer("paraphrase-MiniLM-L6-v2")

    def encode_batch(inputs):
        sentences = inputs.tolist()
        return encoder.encode(sentences)

    return encode_batch

In [10]:
# Wrap predict_batch_fn as a Spark UDF: each row's result is declared as an
# array of floats, and the UDF is configured with batch_size=10.
encode = predict_batch_udf(predict_batch_fn,
                           return_type=ArrayType(FloatType()),
                           batch_size=10)

In [11]:
%%time
# first pass caches model/fn
# Here the "lines" column is passed to the UDF wrapped in struct(); collect()
# returns every resulting row to the driver.
embeddings = df.withColumn("encoding", encode(struct("lines")))
results = embeddings.collect()



CPU times: user 13.6 ms, sys: 474 μs, total: 14.1 ms
Wall time: 5.53 s


                                                                                

In [12]:
%%time
# Same UDF call, this time passing the input column by its name.
embeddings = df.withColumn("encoding", encode("lines"))
results = embeddings.collect()



CPU times: user 8.65 ms, sys: 4.27 ms, total: 12.9 ms
Wall time: 5.44 s


                                                                                

In [13]:
%%time
# Same UDF call, this time passing the input column as a col() expression.
embeddings = df.withColumn("encoding", encode(col("lines")))
results = embeddings.collect()



CPU times: user 5.11 ms, sys: 5.83 ms, total: 10.9 ms
Wall time: 5.27 s


                                                                                

In [14]:
embeddings.show(truncate=60)



+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|A ridiculous movie, a terrible editing job, worst screenp...|[-0.13450998, -0.53543544, 0.054044724, -0.1395307, 0.549...|
|Most of this film was okay, for a sequel of a sequel of a...|[-0.059694894, 0.13422318, -0.008580661, 0.10549253, -0.1...|
|                                                     I tried|[0.36901277, 0.09817391, 0.44426093, -0.41252792, -0.3193...|
|This movie attempted to make Stu Ungar's life interesting...|[-0.060258333, -0.15493791, -0.16713744, 0.31275272, 0.02...|
|After I saw this I concluded that it was most likely a ch...|[0.13928875, -0.20784612, -0.22824976, -0.054931596, -0.0...|
|Jeff Sp

                                                                                

### Using Triton Inference Server

Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with versions of Python and NumPy that are compatible with Triton 24.08, using a conda-pack environment created as follows:
```
conda create -n huggingface -c conda-forge python=3.10.0
conda activate huggingface

export PYTHONNOUSERSITE=True
pip install "numpy<2" conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface.tar.gz
```

In [15]:
import numpy as np
import os
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [16]:
%%bash
# Abort at the first failing command so that a missing model directory or
# environment archive is reported by this cell instead of being masked by the
# exit status of a later command.
set -euo pipefail

# copy custom model to expected layout for Triton
rm -rf models
mkdir -p models
cp -r models_config/hf_transformer models

# add custom execution environment
cp huggingface.tar.gz models

#### Start Triton Server on each executor

In [17]:
num_executors = 1

# Host paths used as container volumes by start_triton: the "models" directory
# under the current working directory and the user's Hugging Face cache.
triton_models_dir = f"{os.getcwd()}/models"
huggingface_cache_dir = f"{os.path.expanduser('~')}/.cache/huggingface"

# One element and one partition per executor.
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    """Start a Triton Inference Server container and wait until it is ready.

    Meant to be run through ``mapPartitions``. If Docker already lists a
    container matching the name "spark-triton", its id is printed and nothing
    else is done. Otherwise a new container is launched and the server is
    polled over gRPC at "localhost:8001" until it reports ready.

    Args:
        it: Partition iterator supplied by the caller; it is not read.

    Returns:
        ``[True]``, a single marker element for this partition.

    Raises:
        TimeoutError: If a newly started server still does not report ready
            once the polling budget has been used up. The container is left
            running so that its logs can be inspected.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    startup_grace_secs = 15   # pause before the first readiness probe
    poll_interval_secs = 5    # pause between two readiness probes
    ready_timeout_secs = 300  # total pause time allowed across all probes

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
        return [True]

    container = docker_client.containers.run(
        "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
        detach=True,
        device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
        environment=[
            "TRANSFORMERS_CACHE=/cache"
        ],
        name="spark-triton",
        network_mode="host",
        remove=True,
        shm_size="512M",
        volumes={
            triton_models_dir: {"bind": "/models", "mode": "ro"},
            huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
        }
    )
    print(">>>> starting triton: {}".format(container.short_id))

    # wait for triton to be running
    time.sleep(startup_grace_secs)
    triton_client = grpcclient.InferenceServerClient("localhost:8001")
    waited_secs = 0
    while True:
        try:
            if triton_client.is_server_ready():
                break
        except Exception:
            # Errors are expected while the server is still starting up;
            # treat them the same as "not ready yet".
            pass
        if waited_secs >= ready_timeout_secs:
            raise TimeoutError(
                "Triton container {} was not ready after {} seconds".format(
                    container.short_id, startup_grace_secs + waited_secs
                )
            )
        # Pause after every unsuccessful probe, whether it raised or returned
        # False, so the loop never spins on the gRPC endpoint.
        time.sleep(poll_interval_secs)
        waited_secs += poll_interval_secs

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

>>>> starting triton: a6855c050b3a                                  (0 + 1) / 1]
                                                                                

[True]

#### Run inference

In [18]:
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [19]:
# Only use the first 100 rows of the parquet data, since inference is slow;
# the limited DataFrame is cached because the following cells reuse it.
df = spark.read.parquet("imdb_test").limit(100).cache()

24/09/25 17:07:00 WARN CacheManager: Asked to cache already cached data.


In [20]:
def triton_fn(triton_uri, model_name):
    """Build a predict function that sends batches to a Triton model over gRPC.

    A gRPC client is created and the model's metadata is fetched once, when
    this factory is called; the returned function reuses both on every call.

    Args:
        triton_uri: Address passed to ``InferenceServerClient``.
        model_name: Name of the model to query and to run inference on.

    Returns:
        A ``predict(inputs)`` callable. ``inputs`` is either a single numpy
        array (sent as the model's first input) or a mapping from input name
        to numpy array. Each array is cast to the numpy dtype that corresponds
        to the datatype declared in the model metadata. The result is a single
        numpy array when the model declares one output, otherwise a dict
        mapping output name to numpy array.

    Raises:
        KeyError: From ``predict`` if a declared input datatype is not in the
            dtype table below, or if a named input is missing from ``inputs``.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype used to cast inputs before sending.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "UINT8": np.dtype(np.uint8),
        "UINT16": np.dtype(np.uint16),
        "UINT32": np.dtype(np.uint32),
        "UINT64": np.dtype(np.uint64),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object),
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            first_input = model_meta.inputs[0]
            request = [grpcclient.InferInput(first_input.name, inputs.shape, first_input.datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[first_input.datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                # Metadata entries expose name/datatype as attributes, whereas
                # InferInput objects expose them through accessor methods.
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [21]:
# Wrap triton_fn, bound to the server at "localhost:8001" and the model
# "hf_transformer", as a Spark UDF returning an array of floats per row.
# The UDF is configured with batch_size=100 and one input tensor shape of [1].
encode = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="hf_transformer"),
                           return_type=ArrayType(FloatType()),
                           input_tensor_shapes=[[1]],
                           batch_size=100)

In [22]:
%%time
# first pass caches model/fn
# Here the "lines" column is passed to the UDF wrapped in struct(); collect()
# returns every resulting row to the driver.
embeddings = df.withColumn("encoding", encode(struct("lines")))
results = embeddings.collect()

[Stage 19:>                                                         (0 + 1) / 1]

CPU times: user 10.5 ms, sys: 762 μs, total: 11.3 ms
Wall time: 1.47 s


                                                                                

In [23]:
%%time
# Same UDF call, this time passing the input column by its name.
embeddings = df.withColumn("encoding", encode("lines"))
results = embeddings.collect()

CPU times: user 5.41 ms, sys: 0 ns, total: 5.41 ms
Wall time: 207 ms


In [24]:
%%time
# Same UDF call, this time passing the input column as a col() expression.
embeddings = df.withColumn("encoding", encode(col("lines")))
results = embeddings.collect()

CPU times: user 4.66 ms, sys: 196 μs, total: 4.86 ms
Wall time: 198 ms


In [25]:
embeddings.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|A ridiculous movie, a terrible editing job, worst screenp...|[-0.13450995, -0.5354353, 0.054044887, -0.13953091, 0.549...|
|Most of this film was okay, for a sequel of a sequel of a...|[-0.05969492, 0.13422322, -0.008580784, 0.10549266, -0.15...|
|                                                     I tried|[0.36901197, 0.09817381, 0.44426036, -0.41252798, -0.3193...|
|This movie attempted to make Stu Ungar's life interesting...|[-0.060258564, -0.15493858, -0.16713741, 0.31275284, 0.02...|
|After I saw this I concluded that it was most likely a ch...|[0.13928875, -0.20784627, -0.22824982, -0.054931726, -0.0...|
|Jeff Sp

#### Stop Triton Server on each executor

In [26]:
def stop_triton(it):
    """Stop the Triton containers that Docker lists for the "spark-triton" name.

    Meant to be run through ``mapPartitions``. Every container returned by the
    name filter is stopped, which matches the ids printed beforehand.

    Args:
        it: Partition iterator supplied by the caller; it is not read.

    Returns:
        ``[True]``, a single marker element for this partition.
    """
    import docker

    client = docker.from_env()
    containers = client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    for container in containers:
        # Allow up to 120 seconds for the container to stop.
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

>>>> stopping containers: ['a6855c050b3a']
                                                                                

[True]

In [27]:
spark.stop()