# PySpark DL Inference on Dataproc
### Conditional generation with Hugging Face

In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation.  
From: https://huggingface.co/docs/transformers/model_doc/t5

In [1]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct
from pyspark.ml.functions import predict_batch_udf
from pyspark import SparkConf

In [None]:
import glob
import time
import json
import pandas as pd
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

Load the test split of the IMDB Movie Reviews dataset from Hugging Face, keeping only the review text.

In [3]:
dataset = load_dataset("imdb", split="test")
dataset = dataset.to_pandas().drop(columns="label")

## PySpark

In [None]:
def initialize_spark():
    rapids_jar = glob.glob("/usr/lib/spark/jars/rapids-4-spark*")[-1]
    python_path="/opt/conda/miniconda3/python3.11"
    conda_lib_path="/opt/conda/miniconda3/lib"

    conf = SparkConf()
    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.sql.shuffle.partitions", "200")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.rpc.message.maxSize", "1024")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")
    conf.set("spark.rapids.ml.uvm.enabled", "true")
    conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_lib_path}:$LD_LIBRARY_PATH") # PyTriton needs libpython3.11.so
    conf.set("spark.executorEnv.PYTHONPATH", f"{rapids_jar}:{python_path}")
    conf.set("spark.rapids.memory.gpu.minAllocFraction", "0.0001")
    conf.set("spark.plugins", "com.nvidia.spark.SQLPlugin")
    conf.set("spark.locality.wait", "0s")
    conf.set("spark.sql.cache.serializer", "com.nvidia.spark.ParquetCachedBatchSerializer")
    conf.set("spark.rapids.memory.gpu.pooling.enabled", "false")
    conf.set("spark.sql.execution.sortBeforeRepartition", "false")
    conf.set("spark.rapids.sql.format.parquet.reader.type", "MULTITHREADED")
    conf.set("spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel", "20")
    conf.set("spark.rapids.sql.multiThreadedRead.numThreads", "20")
    conf.set("spark.rapids.sql.python.gpu.enabled", "true")
    conf.set("spark.rapids.memory.pinnedPool.size", "2G")
    conf.set("spark.python.daemon.module", "rapids.daemon")
    conf.set("spark.rapids.sql.batchSizeBytes", "512m")
    conf.set("spark.sql.adaptive.enabled", "false")
    conf.set("spark.sql.files.maxPartitionBytes", "512m")
    conf.set("spark.rapids.sql.concurrentGpuTasks", "2")
    conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000")
    conf.set("spark.rapids.sql.explain", "NONE")
    
    spark = SparkSession.builder.appName("spark-dl").config(conf=conf).getOrCreate()
    return spark

if 'spark' not in globals():
    print("Initializing Spark session.")
    spark = initialize_spark()
else:
    print("Using existing Spark session.")

Initializing Spark session.


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/13 18:35:31 INFO SparkEnv: Registering MapOutputTracker
24/11/13 18:35:31 INFO SparkEnv: Registering BlockManagerMaster
24/11/13 18:35:31 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
24/11/13 18:35:31 INFO SparkEnv: Registering OutputCommitCoordinator
24/11/13 18:35:33 WARN RapidsPluginUtils: RAPIDS Accelerator 24.08.1 using cudf 24.08.0, private revision 9fac64da220ddd6bf5626bd7bd1dd74c08603eac
24/11/13 18:35:33 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.


#### Create PySpark DataFrame

In [5]:
df = spark.createDataFrame(dataset).repartition(16)
df.schema

StructType([StructField('text', StringType(), True)])

In [6]:
df.count()

                                                                                

25000

In [7]:
df.take(1)

24/11/13 18:36:05 WARN TaskSetManager: Stage 3 contains a task of very large size (2020 KiB). The maximum recommended task size is 1000 KiB.


[Row(text="Miles O'keefe stars as Ator, a loin-clothed hero who resembles a Chippendale's dancer. The Conan-wannabe must do battle with an evil guy in a Cher wig, and protect the Earth from the Geometric Nucleus, a sort of primitive atomic bomb. Watch closely for visible sunglasses and tire-tracks. Mystery Science Theater 3000 made fun of it under the title CAVE DWELLERS.")]

In [8]:
data_path = "/tmp/datasets/imdb_test"
df.write.mode("overwrite").parquet(data_path)

24/11/13 18:36:10 WARN TaskSetManager: Stage 5 contains a task of very large size (2020 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

#### Load and preprocess DataFrame

We'll take the first sentence of each review (approximated as the text before the first period) as our target for translation.

In [9]:
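from pyspark.sql import Column

def preprocess(text: Column, prefix: str = "") -> Column:
    # Keep the text before the first "." (a rough first sentence) and prepend the task prefix.
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)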
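Because `_preprocess` applies a plain string rule, the rule can be checked locally without Spark. The sketch below reimplements it in pure Python under the hypothetical name `first_sentence_with_prefix`, using made-up sample strings. It also shows one side effect of splitting on `.`: a decimal point ends the "sentence" early, as it would for a review that begins "this 2.5 hour diluted snore-fest".

```python
from typing import Iterable, List


def first_sentence_with_prefix(texts: Iterable[str], prefix: str = "") -> List[str]:
    # Hypothetical helper mirroring _preprocess: keep the text before the first "."
    # and prepend the prefix. Text without any "." is kept whole, because
    # str.split returns the original string as the only element.
    return [prefix + s.split(".")[0] for s in texts]


if __name__ == "__main__":
    samples = [
        "I was really disappointed by this movie. Made-up second sentence.",
        "this 2.5 hour diluted snore-fest",
        "No period at all",
    ]
    for out in first_sentence_with_prefix(samples, "translate English to French: "):
        print(out)
    # translate English to French: I was really disappointed by this movie
    # translate English to French: this 2
    # translate English to French: No period at all
```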

In [10]:
df = spark.read.parquet(data_path).limit(1024).repartition(16)
df.show(truncate=100)

                                                                                

+----------------------------------------------------------------------------------------------------+
|                                                                                                text|
+----------------------------------------------------------------------------------------------------+
|G&M started a the odd couple downstairs in Man About the House and went on to amusing the nation ...|
|Simply miserable Lana Turner-Ezio Pinza vehicle. Pinza had a beautiful voice but he rarely uses i...|
|this 2.5 hour diluted snore-fest appears to be one of the poorest excuses for an adaptation, ever...|
|seriously, if i wanted to make a movie that makes zero sense, never will, and features lesbian sc...|
|I'm usually quite tolerant of movies, and very easily entertained, however this movie was dreadfu...|
|Worse than mediocre thriller about an abused wife who goes on the lam after she is linked circums...|
|**Could be considered some mild spoilers, but no more than in anyone els

In [11]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()
input_df.show(truncate=100)


+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|                              translate English to French: 1983's "Frightmare" is an odd little film|
|                                translate English to French: The film made no sense to me whatsoever|
|translate English to French: I loved the first 15 minutes, and I loved some of the dialogue in th...|
|                    translate English to French: Don't tell me this film was funny or a little funny|
|                                translate English to French: I was really disappointed by this movie|
|       translate English to French: It's about time for a female boxing flick, but this one ain't it|
|translate English to French: I'm actually too drained to write this revi

                                                                                

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- `predict_batch_fn` uses PyTorch and Hugging Face Transformers APIs to load the model, and returns a `predict` function that operates on NumPy arrays.
- `predict_batch_udf` converts the Spark DataFrame columns into NumPy input batches and passes them to the `predict` function.
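The batching contract can be modeled without Spark: the predict function receives at most `batch_size` rows at a time and must return one output per input row. The sketch below is a simplified illustration of that contract only, not PySpark's implementation (which works on Arrow/pandas batches inside each partition); `iter_batches`, `apply_in_batches`, and `fake_predict` are hypothetical names. Under this model, with `batch_size=32` and roughly 64 rows per partition (1024 rows over 16 partitions), each partition would be processed in about two calls.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(rows: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Hypothetical helper: yield consecutive slices of at most batch_size rows."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(rows), batch_size):
        # The final slice may be shorter than batch_size.
        yield list(rows[start:start + batch_size])


def apply_in_batches(
    predict: Callable[[List[str]], List[str]],
    rows: Sequence[str],
    batch_size: int,
) -> List[str]:
    """Hypothetical helper: call predict once per batch, concatenating outputs in input order."""
    outputs: List[str] = []
    for batch in iter_batches(rows, batch_size):
        preds = predict(batch)
        # One prediction per input row, otherwise rows and outputs would misalign.
        if len(preds) != len(batch):
            raise ValueError(f"expected {len(batch)} outputs, got {len(preds)}")
        outputs.extend(preds)
    return outputs


if __name__ == "__main__":
    seen_sizes = []

    def fake_predict(batch: List[str]) -> List[str]:
        # Hypothetical stand-in for the model: record the batch size, upper-case the text.
        seen_sizes.append(len(batch))
        return [s.upper() for s in batch]

    rows = [f"sentence {i}" for i in range(70)]
    preds = apply_in_batches(fake_predict, rows, batch_size=32)
    print(seen_sizes)  # [32, 32, 6]
    print(preds[:2])   # ['SENTENCE 0', 'SENTENCE 1']
```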

In [12]:
def predict_batch_fn():
    import numpy as np
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from pyspark import TaskContext

    print(f"Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device.")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        # Flatten the batch into a 1-D list of strings. np.ravel still returns a list
        # for a single-row batch, where np.squeeze would return a bare string.
        flattened = np.ravel(inputs).tolist()
        encoded = tokenizer(flattened,
                            padding=True,
                            return_tensors="pt").to(device)
        outputs = model.generate(input_ids=encoded["input_ids"],
                                 attention_mask=encoded["attention_mask"],
                                 max_length=128)
        # Decode the generated token IDs into one translated string per input row.
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
        print("predict: {}".format(len(flattened)))
        return string_outputs

    return predict

In [13]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=32)

In [14]:
start_time = time.time()
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()
print(f"{time.time() - start_time} seconds")



19.314849376678467 seconds


                                                                                

In [15]:
start_time = time.time()
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()
print(f"{time.time() - start_time} seconds")



10.648898124694824 seconds


                                                                                

In [16]:
start_time = time.time()
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()
print(f"{time.time() - start_time} seconds")



10.734241485595703 seconds
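Each timing cell repeats the same `start_time` bookkeeping. One way to factor it out is a small context manager; the sketch below uses the hypothetical name `timed` and measures with `time.perf_counter`, a monotonic clock intended for interval timing, rather than `time.time`.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(label: str) -> Iterator[Dict[str, float]]:
    """Hypothetical helper: measure wall-clock time of the with-block and print it."""
    result: Dict[str, float] = {}
    start = time.perf_counter()
    try:
        yield result
    finally:
        # Record and report the elapsed time even if the block raised.
        result["seconds"] = time.perf_counter() - start
        print(f"{label}: {result['seconds']:.2f} seconds")


if __name__ == "__main__":
    with timed("busy loop") as t:
        sum(range(100_000))
    print(t["seconds"] >= 0)  # True
```

In the notebook this could wrap a pass as `with timed("second pass"): results = preds.collect()`.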


                                                                                

In [17]:
preds.show(truncate=50)


+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: 1983's "Frightmare...|Le film "Frightmare" de 1983 est un petit film ...|
|translate English to French: The film made no s...|                 Le film n'a rien de sens pour moi|
|translate English to French: I loved the first ...|J'ai aimé les 15 premières minutes et j'ai aimé...|
|translate English to French: Don't tell me this...|Ne me dit pas que ce film était amusant ou un p...|
|translate English to French: I was really disap...|                Je suis vraiment déçu par ce film.|
|translate English to French: It's about time fo...|Il est temps de faire un féminisme de boxe, mai...|
|translate English to French: I'm actually too d...|Je suis en f

                                                                                

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.
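The launch/serve/shutdown lifecycle in these steps can be tried locally with the standard library alone. In the sketch below, the hypothetical `fake_server` stands in for the Triton process: it installs a SIGTERM handler (as `triton_server` does), signals readiness, and loops until the handler fires; the parent then sends SIGTERM and waits for a clean exit. It assumes a POSIX system (the `fork` start method and `os.kill` with `SIGTERM`) and involves neither Spark nor Triton.

```python
import multiprocessing as mp
import os
import signal
import time


def fake_server(ready, stopped_cleanly):
    """Hypothetical stand-in for triton_server(): block until SIGTERM arrives."""
    running = True

    def handle_sigterm(signum, frame):
        nonlocal running
        running = False  # plays the role of triton.stop()

    # Install the handler before announcing readiness, so any SIGTERM sent
    # after the parent sees "ready" reaches the handler.
    signal.signal(signal.SIGTERM, handle_sigterm)
    ready.set()
    while running:  # plays the role of the blocking triton.serve()
        time.sleep(0.05)
    stopped_cleanly.set()


def start_server(ctx):
    """Hypothetical helper: launch the server in a child process and wait until it is ready."""
    ready, stopped_cleanly = ctx.Event(), ctx.Event()
    proc = ctx.Process(target=fake_server, args=(ready, stopped_cleanly))
    proc.start()
    if not ready.wait(timeout=10):
        proc.kill()
        raise TimeoutError("server did not become ready")
    return proc, stopped_cleanly


def stop_server(proc, timeout_s=10.0):
    """Hypothetical helper: send SIGTERM, wait for exit; True means exit code 0."""
    os.kill(proc.pid, signal.SIGTERM)
    proc.join(timeout_s)  # join also reaps the child, so it does not linger as a zombie
    return proc.exitcode == 0


if __name__ == "__main__":
    ctx = mp.get_context("fork")  # POSIX-only start method; fake_server need not be pickled
    proc, stopped_cleanly = start_server(ctx)
    print("server pid:", proc.pid)
    print("clean shutdown:", stop_server(proc) and stopped_cleanly.is_set())
```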

In [18]:
from functools import partial

In [19]:
def triton_server():
    import signal
    import numpy as np
    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton
    from pyspark import TaskContext

    with Triton() as triton:
        print(f"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.")
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        model = T5ForConditionalGeneration.from_pretrained("t5-small")
        
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"SERVER: Using {DEVICE} device.")
        model = model.to(DEVICE)

        @batch
        def _infer_fn(**inputs):
            # inputs["text"] holds one byte string per row; flatten it to a 1-D list.
            # np.ravel (unlike np.squeeze) still yields a list for a single-row batch.
            sentences = np.ravel(inputs["text"]).tolist()
            print(f"SERVER: Received batch of size {len(sentences)}")
            decoded_sentences = [s.decode("utf-8") for s in sentences]
            encoded = tokenizer(decoded_sentences,
                                padding=True,
                                return_tensors="pt").to(DEVICE)
            output_ids = model.generate(input_ids=encoded["input_ids"],
                                        attention_mask=encoded["attention_mask"],
                                        max_length=128)
            # One translation per row, shaped (batch_size, 1) like the request.
            outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])
            return {
                "translations": outputs,
            }

        triton.bind(
            model_name="ConditionalGeneration",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="translations", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

def start_triton(url, model_name):
    import socket
    from multiprocessing import Process
    from pytriton.client import ModelClient

    hostname = socket.gethostname()
    process = Process(target=triton_server)
    process.start()

    client = ModelClient(url, model_name)
    ready = False
    while not ready:
        try:
            client.wait_for_server(5)
            ready = True
        except Exception as e:
            print(f"Waiting for server to be ready: {e}")
    
    return [(hostname, process.pid)]
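`start_triton` retries `client.wait_for_server(5)` in an unbounded loop, so a server that never comes up would leave the barrier stage hanging. One option is a retry helper with an overall deadline. The sketch below uses the hypothetical name `wait_until_ready` and accepts any zero-argument callable that raises until the service is ready; calling it inside `start_triton` as `wait_until_ready(lambda: client.wait_for_server(5), timeout_s=600)` is an untested assumption.

```python
import time
from typing import Callable


def wait_until_ready(check: Callable[[], object], timeout_s: float, retry_delay_s: float = 1.0) -> int:
    """Hypothetical helper: call check() until it returns without raising.

    Gives up with TimeoutError once timeout_s has elapsed; returns the number of attempts.
    """
    deadline = time.monotonic() + timeout_s
    attempt = 0
    while True:
        attempt += 1
        try:
            check()
            return attempt
        except Exception as exc:  # the check signals "not ready yet" by raising
            if time.monotonic() >= deadline:
                raise TimeoutError(f"not ready after {attempt} attempts: {exc}") from exc
            print(f"Waiting for server to be ready: {exc}")
            time.sleep(retry_delay_s)
```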

#### Start Triton servers

To ensure that only one Triton inference server is started per node, we use stage-level scheduling to place each server task on a separate executor, and therefore on a separate GPU.  
Stage-level scheduling requires that `spark.executor.cores` is set and that `spark.executor.resource.gpu.amount` <= 1.
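The core-count rule used below can be checked with plain arithmetic: if each server task requests more than half of an executor's cores, two such tasks cannot fit on one executor. The sketch mirrors the `task_cores` expression in `_use_stage_level_scheduling`; the function names are hypothetical, and the check considers CPU cores only (not GPUs or memory).

```python
def stage_task_cores(executor_cores: int, rapids_sql_enabled: bool) -> int:
    # Hypothetical helper, same expression as task_cores in _use_stage_level_scheduling:
    # with RAPIDS SQL enabled, claim every core so no ETL task shares the executor;
    # otherwise claim just over half of them.
    if rapids_sql_enabled:
        return executor_cores
    return executor_cores // 2 + 1


def tasks_per_executor(executor_cores: int, task_cores: int) -> int:
    # Hypothetical helper: how many tasks of this size fit on one executor by cores alone.
    return executor_cores // task_cores


if __name__ == "__main__":
    # The notebook sets spark.executor.cores=8 and enables the RAPIDS plugin.
    print(stage_task_cores(8, rapids_sql_enabled=True))   # 8
    print(stage_task_cores(8, rapids_sql_enabled=False))  # 5
    print(tasks_per_executor(8, 5))                       # 1
```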

In [20]:
def _use_stage_level_scheduling(spark, rdd):

    # Compare (major, minor) as integers; string comparison would misorder e.g. "3.10" and "3.4".
    major, minor = (int(v) for v in spark.version.split(".")[:2])
    if (major, minor) < (3, 4):
        raise Exception("Stage-level scheduling is not supported in Spark < 3.4.0")

    # A default of None makes get() return None instead of raising when the key is unset.
    executor_cores = spark.conf.get("spark.executor.cores", None)
    assert executor_cores is not None, "spark.executor.cores is not set"
    executor_gpus = spark.conf.get("spark.executor.resource.gpu.amount", None)
    assert executor_gpus is not None and float(executor_gpus) <= 1, "spark.executor.resource.gpu.amount must be set and <= 1"

    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    # Each server task requests more than half of the executor's cores, which ensures
    # that no two server tasks are scheduled on the same executor.
    #
    # task_cores cannot be smaller than executor cores / 2: requesting task_gpus=1.0 alone
    # does not guarantee that the tasks are sent to different executors.
    #
    # If spark-rapids is enabled, the task requests all executor cores so that no other
    # ETL tasks run alongside it, to avoid GPU OOM.
    spark_plugins = spark.conf.get("spark.plugins", " ")
    assert spark_plugins is not None
    spark_rapids_sql_enabled = spark.conf.get("spark.rapids.sql.enabled", "true")
    assert spark_rapids_sql_enabled is not None

    task_cores = (
        int(executor_cores)
        if "com.nvidia.spark.SQLPlugin" in spark_plugins
        and "true" == spark_rapids_sql_enabled.lower()
        else (int(executor_cores) // 2) + 1
    )
    # task_gpus is the amount of a single GPU address each task requires, not a number
    # of GPUs, so it can be any value in (0, 0.5] or 1.
    task_gpus = 1.0

    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    rp = ResourceProfileBuilder().require(treqs).build  # `build` is a property, not a method

    print(f"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})")

    return rdd.withResources(rp)
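A version guard written as `spark.version < "3.4.0"` compares strings character by character. That happens to work for current releases, but string order and release order diverge once a component has two digits. The sketch below (hypothetical helper `version_tuple`) shows the difference and a numeric alternative; the "3.10.0" input is a hypothetical version used only to expose the pitfall.

```python
from typing import Tuple


def version_tuple(version: str) -> Tuple[int, int]:
    """Hypothetical helper: (major, minor) from strings such as "3.5.1" or "4.0.0-preview1"."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


if __name__ == "__main__":
    # Lexicographic comparison treats the hypothetical "3.10.0" as older than "3.4.0".
    print("3.10.0" < "3.4.0")                                # True
    print(version_tuple("3.10.0") < version_tuple("3.4.0"))  # False
    print(version_tuple("3.5.1") >= (3, 4))                  # True
```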

In [21]:
# Start servers (2 node cluster)
num_nodes = 2
url = "localhost"
model_name = "ConditionalGeneration"

sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)

Requesting stage-level resources: (cores=8, gpu=1.0)


In [None]:
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(url, model_name)).collectAsMap()
# print("Triton Server PIDs:\n", json.dumps(pids, indent=4))

In [23]:
def triton_fn(url, model_name, init_timeout_s):
    import numpy as np
    from pytriton.client import ModelClient

    print(f"Connecting to Triton model {model_name} at {url}.")

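    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            # Flatten to a 1-D list of strings; np.ravel (unlike np.squeeze) still
            # returns a list when the batch holds a single row.
            flattened = np.ravel(inputs).tolist()
            # Encode each string as UTF-8 bytes, one per row: shape (batch_size, 1)
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference, then drop the trailing dimension of size 1
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["translations"], -1)
            return result_data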
        
    return infer_batch

#### Load and preprocess DataFrame

In [24]:
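from pyspark.sql import Column

def preprocess(text: Column, prefix: str = "") -> Column:
    # Keep the text before the first "." (a rough first sentence) and prepend the task prefix.
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)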

In [25]:
df = spark.read.parquet(data_path).limit(1024).repartition(16)

In [26]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()

24/11/13 18:38:46 WARN CacheManager: Asked to cache already cached data.


In [27]:
generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=600),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=32)

In [28]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 23 ms, sys: 3.57 ms, total: 26.5 ms
Wall time: 17.1 s


                                                                                

In [29]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()



CPU times: user 15.8 ms, sys: 4.63 ms, total: 20.5 ms
Wall time: 15.2 s


                                                                                

In [30]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()



CPU times: user 19.9 ms, sys: 4.27 ms, total: 24.2 ms
Wall time: 15.6 s


                                                                                

In [31]:
preds.show(truncate=50)


+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: 1983's "Frightmare...|Le film "Frightmare" de 1983 est un petit film ...|
|translate English to French: The film made no s...|                 Le film n'a rien de sens pour moi|
|translate English to French: I loved the first ...|J'ai aimé les 15 premières minutes et j'ai aimé...|
|translate English to French: Don't tell me this...|Ne me dit pas que ce film était amusant ou un p...|
|translate English to French: I was really disap...|                Je suis vraiment déçu par ce film.|
|translate English to French: It's about time fo...|Il est temps de faire un féminisme de boxe, mai...|
|translate English to French: I'm actually too d...|Je suis en f

                                                                                

#### Shut down servers

In [32]:
def stop_triton(pids):
    import os
    import socket
    import signal
    import time 
    
    hostname = socket.gethostname()
    pid = pids.get(hostname, None)
    assert pid is not None, f"Could not find pid for {hostname}"
    os.kill(pid, signal.SIGTERM)
    time.sleep(7)
    
    for _ in range(5):
        try:
            os.kill(pid, 0)
        except OSError:
            return [True]
        time.sleep(3)

    return [False]

shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)
shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()

Requesting stage-level resources: (cores=8, gpu=1.0)


                                                                                

[True, True]
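`stop_triton` decides that a server is gone by probing with `os.kill(pid, 0)`, which sends no signal and only checks whether the PID can be signalled. Two details matter with this probe on POSIX systems: a `PermissionError` means the process exists but belongs to another user, so treating every `OSError` as "stopped" can give a false positive; and a terminated child that its parent has not yet reaped (a zombie) still passes the probe. The sketch below (hypothetical helper `pid_exists`, POSIX only) separates those cases.

```python
import os
import subprocess
import sys


def pid_exists(pid: int) -> bool:
    """Hypothetical helper: True if a process with this PID exists (POSIX only)."""
    if pid <= 0:
        # 0 and negative values address process groups, not a single process.
        raise ValueError("pid must be positive")
    try:
        os.kill(pid, 0)  # signal 0: existence/permission check only
    except ProcessLookupError:
        return False  # no such process
    except PermissionError:
        return True   # exists, but owned by another user
    return True


if __name__ == "__main__":
    child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(30)"])
    print(pid_exists(child.pid))  # True
    child.terminate()             # sends SIGTERM
    child.wait()                  # reap the child; until then it is a zombie that passes the probe
    print(pid_exists(child.pid))  # False, unless the PID has already been reused
```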