# PySpark Huggingface Inferencing
## Conditional generation with PyTorch

From: https://huggingface.co/docs/transformers/model_doc/t5

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Enable Hugging Face tokenizer parallelism explicitly, so that the tokenizers library does not automatically disable it when it is used together with Python parallelism. See [this thread](https://github.com/huggingface/transformers/issues/5486) for more info.

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
# Load the pretrained tokenizer and seq2seq model from the same checkpoint
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

# Text prefix prepended to every sentence to request the translation task
task_prefix = "translate English to German: "

# Sample English sentences to translate
lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company"
]

# Final model inputs: task prefix followed by the sentence
input_sequences = [f"{task_prefix}{sentence}" for sentence in lines]

In [2]:
# Tokenize the whole batch in one call: pad to the longest sequence,
# truncate at 512 tokens, and return PyTorch tensors.
encoded = tokenizer(
    input_sequences,
    padding="longest",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
input_ids = encoded.input_ids

# Generate output token ids, capping each generated sequence at max_length=20
outputs = model.generate(input_ids, max_length=20)

In [3]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

['Das Haus ist wunderbar',
 'Willkommen in NYC',
 'HuggingFace ist ein Unternehmen']

In [4]:
model.framework

'pt'

## PySpark

In [1]:
import os
from pathlib import Path
from datasets import load_dataset

In [2]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [3]:
import os
import sys

# Prefer the interpreter of the active conda environment. When CONDA_PREFIX is
# not set (e.g. venv or system Python), fall back to the interpreter running
# this kernel instead of building the invalid path "None/bin/python".
conda_env = os.environ.get("CONDA_PREFIX")
python_exec = f"{conda_env}/bin/python" if conda_env else sys.executable

conf = SparkConf()
if 'spark' not in globals():
    # If Spark is not already started with Jupyter, attach to Spark Standalone
    import socket
    hostname = socket.gethostname()
    conf.setMaster(f"spark://{hostname}:7077") # assuming Master is on default port 7077
# Fail fast: do not retry failed tasks
conf.set("spark.task.maxFailures", "1")
conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
# Use the same Python interpreter for the driver and the workers
conf.set("spark.pyspark.python", python_exec)
conf.set("spark.pyspark.driver.python", python_exec)
# Keep full Python tracebacks and JVM stack traces to ease UDF debugging
conf.set("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled", "false")
conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
# Arrow-based columnar transfer, in batches of at most 512 records
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
conf.set("spark.python.worker.reuse", "true")
# Create Spark Session
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

24/10/10 00:10:48 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
24/10/10 00:10:48 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/10 00:10:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
# load IMDB reviews (test) dataset
data = load_dataset("imdb", split="test")

In [5]:
# Keep only the text before the first "." of each review; every entry is a
# one-element list, i.e. one single-column row per review.
lines = [[review["text"].split(".")[0]] for review in data]

len(lines)

25000

### Create PySpark DataFrame

In [6]:
df = spark.createDataFrame(lines, ['lines']).repartition(8)
df.schema

StructType([StructField('lines', StringType(), True)])

In [7]:
df.take(1)

                                                                                

[Row(lines='(Some Spoilers) Dull as dishwater slasher flick that has this deranged homeless man Harry, Darwyn Swalve, out murdering real-estate agent all over the city of L')]

### Save the test dataset as parquet files

In [8]:
df.write.mode("overwrite").parquet("imdb_test")

### Check arrow memory configuration

In [9]:
# Cap the Arrow batch size at 512 records if the session was configured with a larger value
if int(spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")) > 512:
    print("Decreasing `spark.sql.execution.arrow.maxRecordsPerBatch` to ensure the vectorized reader won't run out of memory")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
# `df.head()` returns a single Row (or None for an empty DataFrame), so `len()` of it
# would count columns or raise TypeError. `df.head(1)` returns a list of rows,
# whose length is the actual emptiness check.
assert len(df.head(1)) > 0, "`df` should not be empty"

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [10]:
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [11]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: "pyspark.sql.Column", prefix: str = "") -> "pyspark.sql.Column":
    """Return a column holding `prefix` followed by the first sentence of each row of `text`.

    The pandas UDF is defined inside this function so that it captures `prefix`
    in its closure; only the column itself is passed through Spark.

    Args:
        text: Spark string column to transform (the notebook passes `col("lines")`).
        prefix: Text prepended to every row, e.g. the translation task prompt.

    Returns:
        The column expression produced by applying the pandas UDF to `text`.

    Note:
        Null rows are not handled: `None` has no `split` method, so a null
        value in `text` makes the UDF fail.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # maxsplit=1: only the text before the first "." is needed
        return pd.Series([prefix + s.split(".", 1)[0] for s in text])
    return _preprocess(text)

In [12]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)
df.count()

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|                                       This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                                      I am a big fan of The ABC Movies of the Week genre|
|In the early 1990's "Step-by-Step" came as a tedious combination of the ultra-cheesy "Full House" and the long-defunc...|
|When The Spirits Within was released, all you heard from Final Fantasy fans was how awful the movie was because it di...|
|                                                                    I like to think of myself as a bad movie connoisseur|
|This film did w

100

In [13]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100).cache()

In [14]:
df1.count()

100

In [15]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                         Translate English to German: I am a big fan of The ABC Movies of the Week genre|
|Translate English to German: In the early 1990's "Step-by-Step" came as a tedious combination of the ultra-cheesy "Fu...|
|Translate English to German: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|
|                                       Translate English to German: I like to think of myself as a bad movie connoisseur|
|Translate Engli

In [16]:
def predict_batch_fn():
    """Load the T5 model and tokenizer and return a batch prediction function.

    Imports and model loading happen inside this factory, and the returned
    `predict` closure holds the loaded model and tokenizer.

    Returns:
        A function `predict(inputs)` that takes a numpy array of strings and
        returns a 1-D numpy array with one generated string per input.
    """
    import numpy as np
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        # Flatten to a python list of strings. reshape(-1) is used instead of
        # np.squeeze because squeeze collapses a one-row batch to a 0-d array,
        # whose tolist() is a bare str rather than a one-element list.
        flattened = np.asarray(inputs).reshape(-1).tolist()
        input_ids = tokenizer(flattened,
                              padding="longest",
                              max_length=128,
                              truncation=True,
                              return_tensors="pt").input_ids
        output_ids = model.generate(input_ids, max_length=20)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])
        # Log the batch size on the worker
        print("predict: {}".format(len(flattened)))

        return string_outputs

    return predict

In [17]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=10)

In [18]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 21:>                                                         (0 + 1) / 1]

CPU times: user 6.58 ms, sys: 4.68 ms, total: 11.3 ms
Wall time: 7.41 s


                                                                                

In [19]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 23:>                                                         (0 + 1) / 1]

CPU times: user 1.87 ms, sys: 1.8 ms, total: 3.67 ms
Wall time: 5.71 s


                                                                                

In [20]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 25:>                                                         (0 + 1) / 1]

CPU times: user 2.99 ms, sys: 1.42 ms, total: 4.42 ms
Wall time: 5.69 s


                                                                                

In [21]:
preds.show(truncate=60)

[Stage 27:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: This is so overly clichéd yo...|   Das ist so übertrieben klischeehaft, dass Sie es nach den|
|Translate English to German: I am a big fan of The ABC Mo...|       Ich bin ein großer Fan von The ABC Movies of the Week|
|Translate English to German: In the early 1990's "Step-by...|          Anfang der 1990er Jahre kam "Step-by-Step" als müh|
|Translate English to German: When The Spirits Within was ...|Als The Spirits Within veröffentlicht wurde, hörten Sie v...|
|Translate English to German: I like to think of myself as...|           Ich halte mich gerne als schlechter Filmliebhaber|
|Transla

                                                                                

In [22]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

In [23]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                         Translate English to French: I am a big fan of The ABC Movies of the Week genre|
|Translate English to French: In the early 1990's "Step-by-Step" came as a tedious combination of the ultra-cheesy "Fu...|
|Translate English to French: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|
|                                       Translate English to French: I like to think of myself as a bad movie connoisseur|
|Translate Engli

In [24]:
%%time
# first pass caches model/fn
preds = df2.withColumn("preds", generate(struct("input")))
result = preds.collect()

[Stage 33:>                                                         (0 + 1) / 1]

CPU times: user 2.46 ms, sys: 2.2 ms, total: 4.67 ms
Wall time: 7.38 s


                                                                                

In [25]:
%%time
preds = df2.withColumn("preds", generate("input"))
result = preds.collect()

[Stage 35:>                                                         (0 + 1) / 1]

CPU times: user 3.34 ms, sys: 1.13 ms, total: 4.47 ms
Wall time: 6.1 s


                                                                                

In [26]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
result = preds.collect()

[Stage 37:>                                                         (0 + 1) / 1]

CPU times: user 1.72 ms, sys: 2.89 ms, total: 4.6 ms
Wall time: 5.93 s


                                                                                

In [27]:
preds.show(truncate=60)

[Stage 39:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: This is so overly clichéd yo...|                 Vous ne pouvez pas en tirer d'un tel cliché|
|Translate English to French: I am a big fan of The ABC Mo...|    Je suis un grand fan du genre The ABC Movies of the Week|
|Translate English to French: In the early 1990's "Step-by...|          Au début des années 1990, «Step-by-Step» a été une|
|Translate English to French: When The Spirits Within was ...|Lorsque The Spirits Within a été publié, tout ce que vous...|
|Translate English to French: I like to think of myself as...|       Je me considère comme un mauvais réalisateur de films|
|Transla

                                                                                

### Using Triton Inference Server

Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:
```
conda create -n huggingface-torch -c conda-forge python=3.10.0
conda activate huggingface-torch

export PYTHONNOUSERSITE=True
pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface-torch.tar.gz
```

In [28]:
import os

In [29]:
%%bash
# abort on the first failing command so a missing model config is not silently ignored
set -e

# copy custom model to expected layout for Triton
rm -rf models
mkdir -p models
cp -r models_config/hf_generation_torch models

# add custom execution environment
cp huggingface-torch.tar.gz models

#### Start Triton Server on each executor

In [30]:
num_executors = 1
# Host directories that start_triton bind-mounts into the Triton container
triton_models_dir = f"{os.getcwd()}/models"
huggingface_cache_dir = f"{os.path.expanduser('~')}/.cache/huggingface"
# RDD with `num_executors` elements spread over `num_executors` partitions
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    """Start a Triton Inference Server container on this node unless one is already running.

    Looks for a docker container whose name matches "spark-triton". If none is
    found, launches the Triton 24.08 image with GPU device "0", host networking,
    the model repository mounted read-only at /models and the Hugging Face cache
    mounted at /cache, then polls the gRPC endpoint until the server is ready.

    Args:
        it: Partition iterator supplied by `mapPartitions`; not consumed.

    Returns:
        `[True]`, a one-element list used as the partition result.

    Raises:
        TimeoutError: If the server does not report ready within the wait limit.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container = docker_client.containers.run(
            "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="1G",
            volumes={
                triton_models_dir: {"bind": "/models", "mode": "ro"},
                huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        triton_client = grpcclient.InferenceServerClient("localhost:8001")
        poll_interval_s = 5
        max_wait_s = 600
        deadline = time.monotonic() + max_wait_s
        while True:
            try:
                if triton_client.is_server_ready():
                    break
            except Exception:
                # the server is not accepting connections yet; retry below
                pass
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    "Triton server was not ready after {} seconds".format(max_wait_s)
                )
            # always pause between polls, including when the server answers "not ready"
            time.sleep(poll_interval_s)

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

                                                                                

[True]

#### Run inference

In [31]:
import pandas as pd
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [32]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100).cache()

In [33]:
# only use first sentence and add prefix for conditional generation
def preprocess(text: "pyspark.sql.Column", prefix: str = "") -> "pyspark.sql.Column":
    """Return a column holding `prefix` followed by the first sentence of each row of `text`.

    The pandas UDF is defined inside this function so that it captures `prefix`
    in its closure; only the column itself is passed through Spark.

    Args:
        text: Spark string column to transform (the notebook passes `col("lines")`).
        prefix: Text prepended to every row, e.g. the translation task prompt.

    Returns:
        The column expression produced by applying the pandas UDF to `text`.

    Note:
        Null rows are not handled: `None` has no `split` method, so a null
        value in `text` makes the UDF fail.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # maxsplit=1: only the text before the first "." is needed
        return pd.Series([prefix + s.split(".", 1)[0] for s in text])
    return _preprocess(text)

In [34]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [35]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                         Translate English to German: I am a big fan of The ABC Movies of the Week genre|
|Translate English to German: In the early 1990's "Step-by-Step" came as a tedious combination of the ultra-cheesy "Fu...|
|Translate English to German: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|
|                                       Translate English to German: I like to think of myself as a bad movie connoisseur|
|Translate Engli

In [36]:
def triton_fn(triton_uri, model_name):
    """Create a prediction function that forwards batches to a Triton server over gRPC.

    The client connection and the model metadata are obtained once, when this
    factory is called; the returned `predict` closure reuses both.

    Args:
        triton_uri: Address of the Triton gRPC endpoint, e.g. "localhost:8001".
        model_name: Name of the model to query on the server.

    Returns:
        A function `predict(inputs)`. `inputs` is either a single numpy array
        (sent as the model's first input) or a dict mapping input names to
        numpy arrays. The result is a single numpy array when the model has one
        output, otherwise a dict mapping output names to numpy arrays.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype used to cast inputs before sending.
    # np.bool_ is used because the np.bool8 alias was removed in NumPy 2.0.
    np_types = {
      "BOOL": np.dtype(np.bool_),
      "INT8": np.dtype(np.int8),
      "INT16": np.dtype(np.int16),
      "INT32": np.dtype(np.int32),
      "INT64": np.dtype(np.int64),
      "UINT8": np.dtype(np.uint8),
      "UINT16": np.dtype(np.uint16),
      "UINT32": np.dtype(np.uint32),
      "UINT64": np.dtype(np.uint64),
      "FP16": np.dtype(np.float16),
      "FP32": np.dtype(np.float32),
      "FP64": np.dtype(np.float64),
      "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            first_input = model_meta.inputs[0]
            infer_input = grpcclient.InferInput(first_input.name, inputs.shape, first_input.datatype)
            infer_input.set_data_from_numpy(inputs.astype(np_types[first_input.datatype]))
            request = [infer_input]
        else:
            # dict of multiple ndarray inputs, keyed by the model's input names
            request = []
            for input_meta in model_meta.inputs:
                tensor = inputs[input_meta.name]
                infer_input = grpcclient.InferInput(input_meta.name, tensor.shape, input_meta.datatype)
                infer_input.set_data_from_numpy(tensor.astype(np_types[input_meta.datatype]))
                request.append(infer_input)

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [37]:
generate = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="hf_generation_torch"),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=100)

In [38]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 45:>                                                         (0 + 1) / 1]

CPU times: user 4.61 ms, sys: 1.26 ms, total: 5.87 ms
Wall time: 2.04 s


                                                                                

In [39]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 47:>                                                         (0 + 1) / 1]

CPU times: user 3.16 ms, sys: 641 μs, total: 3.8 ms
Wall time: 1.58 s


                                                                                

In [40]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 49:>                                                         (0 + 1) / 1]

CPU times: user 1.91 ms, sys: 2.38 ms, total: 4.29 ms
Wall time: 1.75 s


                                                                                

In [41]:
preds.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: This is so overly clichéd yo...|   Das ist so übertrieben klischeehaft, dass Sie es nach den|
|Translate English to German: I am a big fan of The ABC Mo...|       Ich bin ein großer Fan von The ABC Movies of the Week|
|Translate English to German: In the early 1990's "Step-by...|          Anfang der 1990er Jahre kam "Step-by-Step" als müh|
|Translate English to German: When The Spirits Within was ...|Als The Spirits Within veröffentlicht wurde, hörten Sie v...|
|Translate English to German: I like to think of myself as...|           Ich halte mich gerne als schlechter Filmliebhaber|
|Transla

In [42]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

24/10/10 00:12:21 WARN CacheManager: Asked to cache already cached data.


In [43]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|          Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                         Translate English to French: I am a big fan of The ABC Movies of the Week genre|
|Translate English to French: In the early 1990's "Step-by-Step" came as a tedious combination of the ultra-cheesy "Fu...|
|Translate English to French: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|
|                                       Translate English to French: I like to think of myself as a bad movie connoisseur|
|Translate Engli

In [44]:
%%time
preds = df2.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 55:>                                                         (0 + 1) / 1]

CPU times: user 3.4 ms, sys: 2.75 ms, total: 6.14 ms
Wall time: 1.96 s


                                                                                

In [45]:
%%time
preds = df2.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 57:>                                                         (0 + 1) / 1]

CPU times: user 3.76 ms, sys: 897 μs, total: 4.66 ms
Wall time: 1.61 s


                                                                                

In [46]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 59:>                                                         (0 + 1) / 1]

CPU times: user 2.61 ms, sys: 2.26 ms, total: 4.87 ms
Wall time: 1.67 s


                                                                                

In [47]:
preds.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: This is so overly clichéd yo...|                 Vous ne pouvez pas en tirer d'un tel cliché|
|Translate English to French: I am a big fan of The ABC Mo...|    Je suis un grand fan du genre The ABC Movies of the Week|
|Translate English to French: In the early 1990's "Step-by...|          Au début des années 1990, «Step-by-Step» a été une|
|Translate English to French: When The Spirits Within was ...|Lorsque The Spirits Within a été publié, tout ce que vous...|
|Translate English to French: I like to think of myself as...|       Je me considère comme un mauvais réalisateur de films|
|Transla

#### Stop Triton Server on each executor

In [48]:
def stop_triton(it):
    """Stop the Triton Inference Server container running on this node, if any.

    Lists docker containers whose name matches "spark-triton", prints their
    short ids, and stops the first match.

    Args:
        it: Partition iterator supplied by `mapPartitions`; not consumed.

    Returns:
        `[True]`, a one-element list used as the partition result.
    """
    import docker

    client = docker.from_env()
    containers = client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    if containers:
        # only the first matching container is stopped; the stop call is given a 120 s timeout
        containers[0].stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

                                                                                

[True]

In [49]:
spark.stop()