# PySpark Huggingface Inferencing
## Conditional generation with PyTorch

From: https://huggingface.co/docs/transformers/model_doc/t5

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Tokenizer and model are loaded from the same pretrained checkpoint.
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

# max_source_length is passed to the tokenizer as `max_length` in the next
# cell; max_target_length is not referenced by the cells visible here.
max_source_length = 512
max_target_length = 128

# Task prefix prepended to every input sentence.
task_prefix = "translate English to German: "

lines = ["The house is wonderful", "Welcome to NYC", "HuggingFace is a company"]

input_sequences = [f"{task_prefix}{sentence}" for sentence in lines]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
# Tokenize the whole batch in one call: pad to the longest sequence, truncate
# at max_source_length tokens, and get PyTorch tensors back.
input_ids = tokenizer(
    input_sequences,
    padding="longest",
    max_length=max_source_length,
    truncation=True,
    return_tensors="pt",
)["input_ids"]

# Generate output token ids for every input sequence in the batch.
outputs = model.generate(input_ids)



In [5]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

['Das Haus ist wunderbar',
 'Willkommen in NYC',
 'HuggingFace ist ein Unternehmen']

In [6]:
model.framework

'pt'

## PySpark

In [7]:
import os
from pathlib import Path
from datasets import load_dataset

In [8]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [None]:
import os
import sys

# Use the active conda environment's interpreter when there is one. When
# CONDA_PREFIX is unset (venv, system Python, ...), fall back to the
# interpreter running this kernel rather than configuring the non-existent
# path "None/bin/python".
conda_env = os.environ.get("CONDA_PREFIX")
python_exec = f"{conda_env}/bin/python" if conda_env else sys.executable

conf = SparkConf()
conf.set("spark.task.maxFailures", "1")
conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
conf.set("spark.pyspark.python", python_exec)
conf.set("spark.pyspark.driver.python", python_exec)
conf.set("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled", "false")
conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.python.worker.reuse", "true")
# Create Spark Session
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

In [10]:
# load IMDB reviews (test) dataset
data = load_dataset("imdb", split="test")

In [11]:
# Keep only the text before the first "." of each review. Every entry is a
# one-element list, i.e. one single-column row per review.
lines = [[review["text"].split(".")[0]] for review in data]

len(lines)

25000

### Create PySpark DataFrame

In [12]:
df = spark.createDataFrame(lines, ['lines']).repartition(10)
df.schema

StructType([StructField('lines', StringType(), True)])

In [13]:
df.take(1)

                                                                                

[Row(lines="i do not understand at all why this movie received such good grades from critics - - i've seen tens of documentaries (on TV) about the wine world which were much much better when (if) you watch it, please think of two very annoying aspects of mondovino : first, the filming is just awful and terrible and upsetting : to me, it looked like the guy behind the camera just received the material and was playing with it : plenty of zooms (for no purpose other than pushing the button in/out) for instance - - i almost stopped to watch it because of that ! secondly, the interviewer (the director i think) is not really relevant : he looks like and ask questions like a boy scout, not like a journalist, even if the general idea and themes would have been interesting, too bad conclusion: overrated documentary, maybe only for guys who do not know nothing about wine => not recommended at all (2/10)")]

### Save the test dataset as parquet files

In [14]:
df.write.mode("overwrite").parquet("imdb_test")

### Check arrow memory configuration

In [15]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
# This line will fail if the vectorized reader runs out of memory
assert len(df.head()) > 0, "`df` should not be empty"

## Inference using Spark DL API (PyTorch)
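Note: you can restart the kernel and run from this point to simulate running in a different node or environment. If you do, first re-run the cells that import PySpark and create the Spark session, because the cells below use `spark`.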

In [16]:
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [17]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    """Return `prefix` + the first sentence of every row of `text`.

    Args:
        text: Column to transform. The call sites in this notebook pass a
            Spark column expression (``col("lines")``), which is handed to
            the inner pandas UDF; the ``pd.Series`` annotation describes the
            batches that the UDF itself receives.
        prefix: Task prefix prepended to each sentence.

    Returns:
        The result of applying the inner pandas UDF to `text`.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # Everything before the first "." is treated as the first sentence.
        # Null rows are passed through as nulls instead of raising
        # AttributeError on `.split`.
        return pd.Series(
            [None if pd.isna(s) else prefix + s.split(".")[0] for s in text]
        )
    return _preprocess(text)

In [18]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)
df.count()

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|                                       This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                                                   I was very disappointed by this movie|
|                                                                             I think vampire movies (usually) are wicked|
|                           Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended to be|
|This film did well at the box office, and the producers of this mess thought the stars had such good chemistry in thi...|
|               

100

In [19]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100).cache()

In [20]:
df1.count()

100

In [21]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                      Translate English to German: I was very disappointed by this movie|
|                                                Translate English to German: I think vampire movies (usually) are wicked|
|Translate English to German: Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended ...|
|Translate English to German: This film did well at the box office, and the producers of this mess thought the stars h...|
|               

In [22]:
def predict_batch_fn():
    """Load the T5 model/tokenizer and return a batch prediction function.

    This factory is handed to `predict_batch_udf`, which calls it to obtain
    `predict`. The imports live inside the function so that they run in the
    Python process that executes the UDF.

    Returns:
        predict(inputs): takes a numpy array of input strings (1-D, or 2-D
        with a single column) and returns a 1-D numpy array with one
        generated string per input.
    """
    import numpy as np
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        # Flatten to a python list of strings. reshape(-1) always yields a
        # 1-D array, so a one-row batch stays a one-element list (np.squeeze
        # would collapse it to a bare str).
        flattened = np.asarray(inputs).reshape(-1).tolist()
        # truncation=True is required for max_length to take effect; without
        # it over-long inputs are passed to the model untruncated.
        input_ids = tokenizer(flattened,
                              padding="longest",
                              max_length=128,
                              truncation=True,
                              return_tensors="pt").input_ids
        # NOTE(review): generate() runs with its default length limit, which
        # presumably explains the cut-off translations in the outputs - confirm.
        output_ids = model.generate(input_ids)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])
        print("predict: {}".format(len(flattened)))
        return string_outputs

    return predict

In [23]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=10)

In [24]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 21:>                                                         (0 + 1) / 1]

CPU times: user 7.14 ms, sys: 7.75 ms, total: 14.9 ms
Wall time: 6.96 s


                                                                                

In [25]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 23:>                                                         (0 + 1) / 1]

CPU times: user 5.07 ms, sys: 5.08 ms, total: 10.1 ms
Wall time: 6.93 s


                                                                                

In [26]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 25:>                                                         (0 + 1) / 1]

CPU times: user 6.79 ms, sys: 4.23 ms, total: 11 ms
Wall time: 7.11 s


                                                                                

In [27]:
preds.show(truncate=60)

[Stage 27:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: This is so overly clichéd yo...|   Das ist so übertrieben klischeehaft, dass Sie es nach den|
|Translate English to German: I was very disappointed by t...|                    Ich war sehr enttäuscht über diesen Film|
|Translate English to German: I think vampire movies (usua...|Ich denke, dass die Vampire-Filme (normalerweise) schlech...|
|Translate English to German: Though not a complete waste ...|Obwohl es sich nicht um eine komplette Verschwendung von ...|
|Translate English to German: This film did well at the bo...|Dieser Film hat sich gut an der Boxoffice ereignet, und d...|
|Transla

                                                                                

In [28]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

In [29]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                      Translate English to French: I was very disappointed by this movie|
|                                                Translate English to French: I think vampire movies (usually) are wicked|
|Translate English to French: Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended ...|
|Translate English to French: This film did well at the box office, and the producers of this mess thought the stars h...|
|               

In [30]:
%%time
# first pass caches model/fn
preds = df2.withColumn("preds", generate(struct("input")))
result = preds.collect()

[Stage 33:>                                                         (0 + 1) / 1]

CPU times: user 2.51 ms, sys: 1.44 ms, total: 3.96 ms
Wall time: 6.94 s


                                                                                

In [31]:
%%time
preds = df2.withColumn("preds", generate("input"))
result = preds.collect()

[Stage 35:>                                                         (0 + 1) / 1]

CPU times: user 3.08 ms, sys: 2.83 ms, total: 5.91 ms
Wall time: 5.74 s


                                                                                

In [32]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
result = preds.collect()

[Stage 37:>                                                         (0 + 1) / 1]

CPU times: user 4.31 ms, sys: 4.11 ms, total: 8.42 ms
Wall time: 5.3 s


                                                                                

In [33]:
preds.show(truncate=60)

[Stage 39:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: This is so overly clichéd yo...|                 Vous ne pouvez pas en tirer d'un tel cliché|
|Translate English to French: I was very disappointed by t...|                               Je suis très déçu par ce film|
|Translate English to French: I think vampire movies (usua...|Je pense que les films vampires (habituellement) sont méc...|
|Translate English to French: Though not a complete waste ...|    Bien qu'il ne soit pas un gaspillage complet de temps, '|
|Translate English to French: This film did well at the bo...|               Ce film a bien avancé à la salle de cinéma et|
|Transla

                                                                                

### Using Triton Inference Server

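Note: you can restart the kernel and run from this point to simulate running in a different node or environment. If you do, first re-run the cells that import PySpark and create the Spark session, because the cells below use `spark` and `sc`.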

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:
```
conda create -n huggingface-torch -c conda-forge python=3.10.0
conda activate huggingface-torch

export PYTHONNOUSERSITE=True
pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface-torch.tar.gz
```

In [34]:
import os

In [35]:
%%bash
# copy custom model to expected layout for Triton
rm -rf models
mkdir -p models
cp -r models_config/hf_generation_torch models

# add custom execution environment
cp huggingface-torch.tar.gz models

#### Start Triton Server on each executor

In [36]:
num_executors = 1
triton_models_dir = "{}/models".format(os.getcwd())
huggingface_cache_dir = "{}/.cache/huggingface".format(os.path.expanduser('~'))
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    """Start a Triton server container on this executor unless one already runs.

    Runs inside `mapPartitions`. If no container matching the name
    "spark-triton" is found, one is started from the Triton image with the
    model repository and Hugging Face cache directories mounted, and the
    function then blocks until the server reports ready.

    Reads `triton_models_dir` and `huggingface_cache_dir` from the enclosing
    notebook scope.

    Args:
        it: Partition iterator supplied by `mapPartitions` (unused).

    Returns:
        [True], so that `collect()` yields one entry per partition.

    Raises:
        TimeoutError: if a newly started server is not ready in time.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    startup_timeout_s = 600  # upper bound on waiting for server readiness
    poll_interval_s = 5      # pause between readiness checks

    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container=client.containers.run(
            "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="1G",
            volumes={
                triton_models_dir: {"bind": "/models", "mode": "ro"},
                huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        triton_client = grpcclient.InferenceServerClient("localhost:8001")
        deadline = time.monotonic() + startup_timeout_s
        last_error = None
        while True:
            try:
                if triton_client.is_server_ready():
                    break
            except Exception as e:
                # The endpoint is typically not reachable yet; remember the
                # error so that it can be reported if the wait times out.
                last_error = e
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    "Triton server not ready after {}s (last error: {})".format(
                        startup_timeout_s, last_error))
            # Sleep on every iteration, including when the server answers
            # "not ready" without raising, to avoid a busy loop.
            time.sleep(poll_interval_s)

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

                                                                                

[True]

#### Run inference

In [37]:
import pandas as pd
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [38]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100).cache()

In [39]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    """Return `prefix` + the first sentence of every row of `text`.

    Args:
        text: Column to transform. The call sites in this notebook pass a
            Spark column expression (``col("lines")``), which is handed to
            the inner pandas UDF; the ``pd.Series`` annotation describes the
            batches that the UDF itself receives.
        prefix: Task prefix prepended to each sentence.

    Returns:
        The result of applying the inner pandas UDF to `text`.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # Everything before the first "." is treated as the first sentence.
        # Null rows are passed through as nulls instead of raising
        # AttributeError on `.split`.
        return pd.Series(
            [None if pd.isna(s) else prefix + s.split(".")[0] for s in text]
        )
    return _preprocess(text)

In [40]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [41]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                      Translate English to German: I was very disappointed by this movie|
|                                                Translate English to German: I think vampire movies (usually) are wicked|
|Translate English to German: Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended ...|
|Translate English to German: This film did well at the box office, and the producers of this mess thought the stars h...|
|               

In [42]:
def triton_fn(triton_uri, model_name):
    """Create a prediction function that sends batches to a Triton server.

    Args:
        triton_uri: gRPC address of the Triton server, e.g. "localhost:8001".
        model_name: Name of the model to query.

    Returns:
        predict(inputs): `inputs` is either a single numpy array (sent as the
        model's first input) or a dict of numpy arrays keyed by input name.
        The result is a single numpy array when the model declares one
        output, otherwise a dict of numpy arrays keyed by output name.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype string -> numpy dtype used to cast arrays before sending.
    # np.bool_ replaces the deprecated np.bool8 alias.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "UINT8": np.dtype(np.uint8),
        "UINT16": np.dtype(np.uint16),
        "UINT32": np.dtype(np.uint32),
        "UINT64": np.dtype(np.uint64),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object),
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            first_input = model_meta.inputs[0]
            request = [grpcclient.InferInput(first_input.name, inputs.shape, first_input.datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[first_input.datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                # Here `i` is an InferInput, whose name()/datatype() are
                # accessor methods, unlike the metadata entries above.
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [43]:
generate = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="hf_generation_torch"),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=100)

In [44]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 45:>                                                         (0 + 1) / 1]

CPU times: user 4.8 ms, sys: 0 ns, total: 4.8 ms
Wall time: 1.91 s


                                                                                

In [45]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 47:>                                                         (0 + 1) / 1]

CPU times: user 4.33 ms, sys: 1.1 ms, total: 5.43 ms
Wall time: 1.52 s


                                                                                

In [46]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 49:>                                                         (0 + 1) / 1]

CPU times: user 4.25 ms, sys: 780 μs, total: 5.03 ms
Wall time: 1.61 s


                                                                                

In [47]:
preds.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: This is so overly clichéd yo...|   Das ist so übertrieben klischeehaft, dass Sie es nach den|
|Translate English to German: I was very disappointed by t...|                    Ich war sehr enttäuscht über diesen Film|
|Translate English to German: I think vampire movies (usua...|Ich denke, dass die Vampire-Filme (normalerweise) schlech...|
|Translate English to German: Though not a complete waste ...|Obwohl es sich nicht um eine komplette Verschwendung von ...|
|Translate English to German: This film did well at the bo...|Dieser Film hat sich gut an der Boxoffice ereignet, und d...|
|Transla

In [48]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

24/10/03 00:32:50 WARN CacheManager: Asked to cache already cached data.


In [49]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                      Translate English to French: I was very disappointed by this movie|
|                                                Translate English to French: I think vampire movies (usually) are wicked|
|Translate English to French: Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended ...|
|Translate English to French: This film did well at the box office, and the producers of this mess thought the stars h...|
|               

In [50]:
%%time
preds = df2.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 55:>                                                         (0 + 1) / 1]

CPU times: user 1.16 ms, sys: 2.87 ms, total: 4.03 ms
Wall time: 1.88 s


                                                                                

In [51]:
%%time
preds = df2.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 57:>                                                         (0 + 1) / 1]

CPU times: user 3.45 ms, sys: 544 μs, total: 4 ms
Wall time: 1.7 s


                                                                                

In [52]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 59:>                                                         (0 + 1) / 1]

CPU times: user 4.81 ms, sys: 5.95 ms, total: 10.8 ms
Wall time: 1.62 s


                                                                                

In [53]:
preds.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: This is so overly clichéd yo...|                 Vous ne pouvez pas en tirer d'un tel cliché|
|Translate English to French: I was very disappointed by t...|                               Je suis très déçu par ce film|
|Translate English to French: I think vampire movies (usua...|Je pense que les films vampires (habituellement) sont méc...|
|Translate English to French: Though not a complete waste ...|    Bien qu'il ne soit pas un gaspillage complet de temps, '|
|Translate English to French: This film did well at the bo...|               Ce film a bien avancé à la salle de cinéma et|
|Transla

#### Stop Triton Server on each executor

In [54]:
def stop_triton(it):
    """Stop the Triton container(s) matching the name "spark-triton".

    Runs inside `mapPartitions`. Every container returned by the name filter
    is stopped, matching the list printed in the log line.

    Args:
        it: Partition iterator supplied by `mapPartitions` (unused).

    Returns:
        [True], so that `collect()` yields one entry per partition.
    """
    import docker

    client=docker.from_env()
    containers=client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    for container in containers:
        # Give the server up to 120 seconds to shut down before it is killed.
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

                                                                                

[True]

In [55]:
spark.stop()