# PySpark Huggingface Inferencing
## Conditional generation with PyTorch

From: https://huggingface.co/docs/transformers/model_doc/t5

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

max_source_length = 512
max_target_length = 128

task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company"
]

input_sequences = [task_prefix + l for l in lines]

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [2]:
# Tokenize the prefixed sentences as one padded batch. The full encoding is
# kept (not just input_ids) so the attention mask can be handed to generate().
encoded = tokenizer(input_sequences,
                    padding="longest",
                    max_length=max_source_length,
                    truncation=True,
                    return_tensors="pt")
input_ids = encoded.input_ids

# Pass the attention mask explicitly so the padded positions added by
# padding="longest" are masked, rather than relying on generate() to infer it.
outputs = model.generate(input_ids=input_ids,
                         attention_mask=encoded.attention_mask)



In [3]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

['Das Haus ist wunderbar',
 'Willkommen in NYC',
 'HuggingFace ist ein Unternehmen']

In [4]:
model.framework

'pt'

## PySpark

In [5]:
import os
from pathlib import Path
from datasets import load_dataset

In [6]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession

In [7]:
num_threads = 6

# Create a local Spark session for demonstration; getOrCreate() reuses an
# existing session if one has already been created.

_config = {
    "spark.master": f"local[{num_threads}]",
    "spark.driver.host": "127.0.0.1",
    "spark.task.maxFailures": "1",
    "spark.driver.memory": "8g",
    "spark.executor.memory": "8g",
    "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
    "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.python.worker.reuse": "true",
}

# Keep the builder under its own name so `spark` only ever refers to a
# SparkSession (never to a half-configured Builder if this cell is interrupted).
_builder = SparkSession.builder.appName("spark-dl-example")
for key, value in _config.items():
    _builder = _builder.config(key, value)
spark = _builder.getOrCreate()

sc = spark.sparkContext

24/09/25 16:39:14 WARN Utils: Your hostname, dgx2h0194.spark.sjc4.nvmetal.net resolves to a loopback address: 127.0.1.1; using 10.150.30.2 instead (on interface enp134s0f0np0)
24/09/25 16:39:14 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/25 16:39:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [8]:
# load IMDB reviews (test) dataset
data = load_dataset("imdb", split="test")

In [9]:
# One single-element row per review: the text up to (not including) the first ".".
lines = [[example["text"].split(".")[0]] for example in data]

len(lines)

25000

### Create PySpark DataFrame

In [10]:
df = spark.createDataFrame(lines, ['lines']).repartition(10)
df.schema

StructType([StructField('lines', StringType(), True)])

In [11]:
df.take(1)

                                                                                

[Row(lines="i do not understand at all why this movie received such good grades from critics - - i've seen tens of documentaries (on TV) about the wine world which were much much better when (if) you watch it, please think of two very annoying aspects of mondovino : first, the filming is just awful and terrible and upsetting : to me, it looked like the guy behind the camera just received the material and was playing with it : plenty of zooms (for no purpose other than pushing the button in/out) for instance - - i almost stopped to watch it because of that ! secondly, the interviewer (the director i think) is not really relevant : he looks like and ask questions like a boy scout, not like a journalist, even if the general idea and themes would have been interesting, too bad conclusion: overrated documentary, maybe only for guys who do not know nothing about wine => not recommended at all (2/10)")]

### Save the test dataset as parquet files

In [12]:
df.write.mode("overwrite").parquet("imdb_test")

                                                                                

### Check arrow memory configuration

In [13]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
# This line will fail if the vectorized reader runs out of memory
assert len(df.head()) > 0, "`df` should not be empty"

## Inference using Spark DL API (PyTorch)
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [14]:
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [15]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    """Keep the text before the first "." of each row and prepend ``prefix``.

    A string-typed pandas UDF is built around ``prefix`` and applied to
    ``text`` immediately. Note that the call sites in this notebook pass a
    Spark column (``col("lines")``) and use the result in ``withColumn``; the
    pandas Series only exists inside the wrapped UDF.

    Args:
        text: the input handed to the pandas UDF (a Spark column at the call
            sites in this notebook).
        prefix: task prefix placed in front of every first sentence.

    Returns:
        Whatever applying the pandas UDF to ``text`` returns.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # Non-string entries (e.g. None for a null row) have no .split();
        # emit a null for them instead of failing the whole batch.
        return pd.Series([
            prefix + s.split(".")[0] if isinstance(s, str) else None
            for s in text
        ])
    return _preprocess(text)

In [16]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)
df.count()

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|A ridiculous movie, a terrible editing job, worst screenplay, ridiculous acting, a story that is completely ununderst...|
|                                                        Most of this film was okay, for a sequel of a sequel of a sequel|
|                                                                                                                 I tried|
|                                             This movie attempted to make Stu Ungar's life interesting by being creative|
|After I saw this I concluded that it was most likely a chick flick; afterward I found out that Keira's mother wrote t...|
|Jeff Speakman n

100

In [17]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100).cache()

In [18]:
df1.count()

                                                                                

100

In [19]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to German: A ridiculous movie, a terrible editing job, worst screenplay, ridiculous acting, a story...|
|                           Translate English to German: Most of this film was okay, for a sequel of a sequel of a sequel|
|                                                                                    Translate English to German: I tried|
|                Translate English to German: This movie attempted to make Stu Ungar's life interesting by being creative|
|Translate English to German: After I saw this I concluded that it was most likely a chick flick; afterward I found ou...|
|Translate Engli

In [20]:
def predict_batch_fn():
    """Load t5-small once and return a batch prediction function.

    The model and tokenizer are loaded inside this function and captured by
    the returned closure, which is what gets passed to ``predict_batch_udf``.

    Returns:
        ``predict(inputs)``: takes a numpy array of strings (1-D, or 2-D with
        one string per row) and returns a 1-D numpy array with one generated
        string per input.
    """
    import numpy as np
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        # reshape(-1) always yields a 1-D array, so a batch holding a single
        # row still becomes a one-element list. (np.squeeze would collapse it
        # to a bare string, and len() below would then report its characters.)
        flattened = np.asarray(inputs).reshape(-1).tolist()
        # max_length is only applied when truncation is enabled.
        encoded = tokenizer(flattened,
                            padding="longest",
                            max_length=128,
                            truncation=True,
                            return_tensors="pt")
        # Hand over the attention mask so padded positions are masked explicitly.
        output_ids = model.generate(input_ids=encoded.input_ids,
                                    attention_mask=encoded.attention_mask)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])
        print("predict: {}".format(len(flattened)))
        return string_outputs

    return predict

In [21]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=10)

In [22]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 19.1 ms, sys: 0 ns, total: 19.1 ms
Wall time: 22.2 s


predict: 10
                                                                                

In [23]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 7.77 ms, sys: 8.69 ms, total: 16.5 ms
Wall time: 18 s


predict: 10
                                                                                

In [24]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 10.9 ms, sys: 2.94 ms, total: 13.8 ms
Wall time: 17.8 s


predict: 10
                                                                                

In [25]:
preds.show(truncate=60)

predict: 10                                                         (0 + 1) / 1]
predict: 10


+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: A ridiculous movie, a terrib...|Ein lächerlicher Film, eine schreckliche Bearbeitung, sch...|
|Translate English to German: Most of this film was okay, ...|Der größte Teil dieses Films war okay, für eine Fortsetzu...|
|                        Translate English to German: I tried|                   Ich habe versucht, Englisch zu übersetzen|
|Translate English to German: This movie attempted to make...|Dieser Film versuchte, das Leben von Stu Ungar interessan...|
|Translate English to German: After I saw this I concluded...|Nach meiner Anzeige kam ich zu dem Schluss, dass es höchs...|
|Transla

predict: 121
                                                                                

In [26]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

In [27]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to French: A ridiculous movie, a terrible editing job, worst screenplay, ridiculous acting, a story...|
|                           Translate English to French: Most of this film was okay, for a sequel of a sequel of a sequel|
|                                                                                    Translate English to French: I tried|
|                Translate English to French: This movie attempted to make Stu Ungar's life interesting by being creative|
|Translate English to French: After I saw this I concluded that it was most likely a chick flick; afterward I found ou...|
|Translate Engli

                                                                                

In [28]:
%%time
# first pass caches model/fn
preds = df2.withColumn("preds", generate(struct("input")))
result = preds.collect()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 16.3 ms, sys: 775 μs, total: 17.1 ms
Wall time: 22.4 s


predict: 10
                                                                                

In [29]:
%%time
preds = df2.withColumn("preds", generate("input"))
result = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 13.3 ms, sys: 1.09 ms, total: 14.4 ms
Wall time: 18.3 s


predict: 10
                                                                                

In [30]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
result = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 15.8 ms, sys: 0 ns, total: 15.8 ms
Wall time: 18.1 s


predict: 10
                                                                                

In [31]:
preds.show(truncate=60)

predict: 10                                                         (0 + 1) / 1]
predict: 10


+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: A ridiculous movie, a terrib...|Un film ridicule, un terrible travail de rédaction, le pi...|
|Translate English to French: Most of this film was okay, ...|La plupart de ce film était en bonne et due forme, pour u...|
|                        Translate English to French: I tried|                                                 J'ai essayé|
|Translate English to French: This movie attempted to make...|Ce film tentait de rendre la vie de Stu Ungar intéressant...|
|Translate English to French: After I saw this I concluded...|Après avoir vu ce film, j'ai conclu qu'il était très prob...|
|Transla

predict: 121
                                                                                

### Using Triton Inference Server

Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:
```
conda create -n huggingface -c conda-forge python=3.10.0
conda activate huggingface

export PYTHONNOUSERSITE=True
pip install "numpy<2" conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface.tar.gz
```

In [32]:
import os

In [33]:
%%bash
# copy custom model to expected layout for Triton
rm -rf models
mkdir -p models
cp -r models_config/hf_generation models

# add custom execution environment
cp huggingface.tar.gz models

#### Start Triton Server on each executor

In [34]:
num_executors = 1
triton_models_dir = "{}/models".format(os.getcwd())
huggingface_cache_dir = "{}/.cache/huggingface".format(os.path.expanduser('~'))
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    """Start a Triton server container on this node unless one is already listed.

    Intended to run once per partition via ``mapPartitions``. Reads
    ``triton_models_dir`` and ``huggingface_cache_dir`` from the enclosing
    notebook scope for the volume mounts.

    Args:
        it: the partition iterator supplied by Spark (not consumed).

    Returns:
        ``[True]`` once a container is listed or the new server reports ready.

    Raises:
        TimeoutError: if a newly started server does not report ready within
            the readiness timeout.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    ready_timeout_s = 600   # upper bound on waiting for the server to report ready
    poll_interval_s = 5

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container = docker_client.containers.run(
            "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="1G",
            volumes={
                triton_models_dir: {"bind": "/models", "mode": "ro"},
                huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        triton_client = grpcclient.InferenceServerClient("localhost:8001")
        deadline = time.monotonic() + ready_timeout_s
        ready = False
        while not ready:
            try:
                ready = triton_client.is_server_ready()
            except Exception:
                # The server may not be accepting connections yet; keep polling.
                ready = False
            if not ready:
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        "Triton server did not become ready within {} seconds".format(ready_timeout_s)
                    )
                # Sleep on both "not ready" and "call failed" so the loop never busy-spins.
                time.sleep(poll_interval_s)

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

>>>> starting triton: eeba2c4778b2                                  (0 + 1) / 1]
                                                                                

[True]

#### Run inference

In [35]:
import pandas as pd
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [36]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100).cache()

In [37]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    """Keep the text before the first "." of each row and prepend ``prefix``.

    A string-typed pandas UDF is built around ``prefix`` and applied to
    ``text`` immediately. Note that the call sites in this notebook pass a
    Spark column (``col("lines")``) and use the result in ``withColumn``; the
    pandas Series only exists inside the wrapped UDF.

    Args:
        text: the input handed to the pandas UDF (a Spark column at the call
            sites in this notebook).
        prefix: task prefix placed in front of every first sentence.

    Returns:
        Whatever applying the pandas UDF to ``text`` returns.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # Non-string entries (e.g. None for a null row) have no .split();
        # emit a null for them instead of failing the whole batch.
        return pd.Series([
            prefix + s.split(".")[0] if isinstance(s, str) else None
            for s in text
        ])
    return _preprocess(text)

In [38]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [39]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to German: A ridiculous movie, a terrible editing job, worst screenplay, ridiculous acting, a story...|
|                           Translate English to German: Most of this film was okay, for a sequel of a sequel of a sequel|
|                                                                                    Translate English to German: I tried|
|                Translate English to German: This movie attempted to make Stu Ungar's life interesting by being creative|
|Translate English to German: After I saw this I concluded that it was most likely a chick flick; afterward I found ou...|
|Translate Engli

In [40]:
def triton_fn(triton_uri, model_name):
    """Build a predict function that sends batches to a Triton server over gRPC.

    A gRPC client is created and the model's metadata is fetched once; both
    are captured by the returned closure.

    Args:
        triton_uri: address passed to ``InferenceServerClient``.
        model_name: name of the model to query and run inference against.

    Returns:
        ``predict(inputs)``: accepts either a single numpy array (sent as the
        model's first input) or a dict of numpy arrays keyed by input name.
        Returns a single numpy array when the model has one output, otherwise
        a dict of numpy arrays keyed by output name.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype that inputs are cast to before sending.
    # np.bool_ is used because the np.bool8 alias was removed in NumPy 2.0.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                # request entries are InferInput objects: name/datatype are read
                # through method calls, unlike the metadata entries above.
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [41]:
generate = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="hf_generation"),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=100)

In [42]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 7.46 ms, sys: 1.34 ms, total: 8.8 ms
Wall time: 8.96 s


                                                                                

In [43]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 47:>                                                         (0 + 1) / 1]

CPU times: user 6.27 ms, sys: 0 ns, total: 6.27 ms
Wall time: 7.82 s


                                                                                

In [44]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 49:>                                                         (0 + 1) / 1]

CPU times: user 6.28 ms, sys: 0 ns, total: 6.28 ms
Wall time: 8.01 s


                                                                                

In [45]:
preds.show(truncate=60)

[Stage 51:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: A ridiculous movie, a terrib...|Ein lächerlicher Film, eine schreckliche Bearbeitung, sch...|
|Translate English to German: Most of this film was okay, ...|Der größte Teil dieses Films war okay, für eine Fortsetzu...|
|                        Translate English to German: I tried|                   Ich habe versucht, Englisch zu übersetzen|
|Translate English to German: This movie attempted to make...|Dieser Film versuchte, das Leben von Stu Ungar interessan...|
|Translate English to German: After I saw this I concluded...|Nach meiner Anzeige kam ich zu dem Schluss, dass es höchs...|
|Transla

                                                                                

In [46]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

24/09/25 16:44:23 WARN CacheManager: Asked to cache already cached data.


In [47]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to French: A ridiculous movie, a terrible editing job, worst screenplay, ridiculous acting, a story...|
|                           Translate English to French: Most of this film was okay, for a sequel of a sequel of a sequel|
|                                                                                    Translate English to French: I tried|
|                Translate English to French: This movie attempted to make Stu Ungar's life interesting by being creative|
|Translate English to French: After I saw this I concluded that it was most likely a chick flick; afterward I found ou...|
|Translate Engli

In [48]:
%%time
preds = df2.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 7.26 ms, sys: 582 μs, total: 7.84 ms
Wall time: 7.58 s


                                                                                

In [49]:
%%time
preds = df2.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 57:>                                                         (0 + 1) / 1]

CPU times: user 1.83 ms, sys: 5.48 ms, total: 7.3 ms
Wall time: 6.82 s


                                                                                

In [50]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 59:>                                                         (0 + 1) / 1]

CPU times: user 8.6 ms, sys: 162 μs, total: 8.76 ms
Wall time: 20.6 s


                                                                                

In [51]:
preds.show(truncate=60)

[Stage 61:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: A ridiculous movie, a terrib...|Un film ridicule, un terrible travail de rédaction, le pi...|
|Translate English to French: Most of this film was okay, ...|La plupart de ce film était en bonne et due forme, pour u...|
|                        Translate English to French: I tried|                                                 J'ai essayé|
|Translate English to French: This movie attempted to make...|Ce film tentait de rendre la vie de Stu Ungar intéressant...|
|Translate English to French: After I saw this I concluded...|Après avoir vu ce film, j'ai conclu qu'il était très prob...|
|Transla

                                                                                

#### Stop Triton Server on each executor

In [52]:
def stop_triton(it):
    """Stop the Triton server container on this node, if one is listed.

    Intended to run once per partition via ``mapPartitions``. Containers are
    looked up with the name filter "spark-triton"; only the first match is
    stopped.

    Args:
        it: the partition iterator supplied by Spark (not consumed).

    Returns:
        ``[True]`` whether or not a container was found.
    """
    import docker

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        # Allow up to 120 seconds for the container to stop.
        containers[0].stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

>>>> stopping containers: ['eeba2c4778b2']
                                                                                

[True]

In [53]:
spark.stop()