# PySpark DL Inference on Databricks
### Conditional generation with Huggingface

In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation.  
From: https://huggingface.co/docs/transformers/model_doc/t5

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Enabling Huggingface tokenizer parallelism so that it is not automatically disabled with Python parallelism.
# See (https://github.com/huggingface/transformers/issues/5486) for more info. 
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company"
]

input_sequences = [task_prefix + l for l in lines]

In [None]:
inputs = tokenizer(input_sequences,
                      padding=True, 
                      return_tensors="pt")

outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=128)

In [None]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

## PySpark

In [0]:
from pyspark.sql.types import *
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct
from pyspark.ml.functions import predict_batch_udf

In [0]:
import time
import json
import pandas as pd
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

In [None]:
conf = SparkConf()

if 'spark' not in globals():
    import os
    import socket
    # If Spark is not already started with Jupyter, attach to Spark Standalone
    conda_env = os.environ.get("CONDA_PREFIX")
    hostname = socket.gethostname()
    conf.setMaster(f"spark://{hostname}:7077") # assuming Master is on default port 7077
    conf.set("spark.driver.memory", "8g")
    conf.set("spark.executor.memory", "8g")
    conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
    conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

# Create Spark Session
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

In [0]:
"""
(Optional for Databricks): For large datasets, we can point the Huggingface dataset cache directory at the cluster's local disk (or at DBFS to persist the cache after cluster termination), rather than the default '/' (ephemeral file system at instance root).
The code below specifies local disk, which enables autoscaling up to 5TB.
For more info, see https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/hugging-face-dataset-download.html.
"""
# LOCAL_DISK_MOUNT = '/local_disk0'
# dbutils.fs.mkdirs(f"file://{LOCAL_DISK_MOUNT}/hf_cache")
# LOCAL_DISK_CACHE_DIR = f'{LOCAL_DISK_MOUNT}/hf_cache/'
# dataset = load_dataset("imdb", split="test", cache_dir=LOCAL_DISK_CACHE_DIR)

Load the IMDB Movie Reviews dataset from Huggingface.

In [0]:
dataset = load_dataset("imdb", split="test")
dataset = dataset.to_pandas().drop(columns="label")



#### Create PySpark DataFrame

In [0]:
df = spark.createDataFrame(dataset).repartition(64)
df.schema

StructType([StructField('text', StringType(), True)])

In [0]:
df.count()

25000

In [0]:
df.take(1)

[Row(text="No, no, no, no, no, no, NO! This is not a film, this is an excuse to show people dancing. This is just not good. Even the dancing is slow and not half as entertaining as the mediocre 'Dirty Dancing', let alone any other good dance movie.<br /><br />Is it a love story? Is it a musical? Is it a drama? Is it a comedy? It's not that this movie is a bit of all, it's that this movie fails at everything it attempts to be. The film turns out to be even more meaningless as the film progresses.<br /><br />Acting is terrible from all sides, the screenplay is definitely trying to tell us something about relationship but fails miserably.<br /><br />WATCH FOR THE MOMENT - When Patrick Stewart enters the scene and you think the film might get better as he brightens up the dull atmosphere. For a second.")]

In [0]:
data_path = "/tmp/datasets/imdb_test"
df.write.mode("overwrite").parquet(data_path)

#### Load and preprocess DataFrame

Define our preprocess function. We'll take the first sentence from each sample as our input for translation.

In [0]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)
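
To see what this transformation does to a single review, here is a minimal plain-Python sketch of the same string logic, without Spark or pandas. The helper name `first_sentence_with_prefix` and the sample strings are made up for illustration.

```python
def first_sentence_with_prefix(text: str, prefix: str = "") -> str:
    """Keep everything before the first period and prepend the task prefix."""
    return prefix + text.split(".")[0]


samples = [
    "This is just not good. Even the dancing is slow.",
    "Another sequel! Why on earth do they keep making these",  # no period: kept whole
]

for s in samples:
    print(first_sentence_with_prefix(s, "translate English to French: "))
```

Because the split is on `"."` only, `!` and `?` do not end a "sentence", and an abbreviation such as `Mr.` cuts the text short. That is acceptable for a throughput demo but worth keeping in mind when reading the translations.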

In [0]:
df = spark.read.parquet(data_path).limit(2048).repartition(64)
df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                                text|
+----------------------------------------------------------------------------------------------------+
|I am and have been a serious collector of Christmas related movies, TV shows, holidays specials, ...|
|Another sequel! Why on earth do they keep making these? This has got to be the weakest 'franchise...|
|The submarine used was NOT Varangian! 'It' was in fact two boats, P614 and P615, both built for T...|
|John Travolta was excellent as "Michael" in the movie by the same name. I don't think a better po...|
|"Who Will Love My Children" Saddest movie I have ever seen. Definite 10/10. Released on TV in 198...|
|This is really bad, the characters were bland, the story was boring, and there is no sex scene. F...|
|Maybe, like most others who have seen this film long after it's premiere

Let's append a prefix to tell the model to translate English to French:

In [0]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()
input_df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|translate English to French: This really is by far the worst movie I've ever seen in my whole lif...|
|translate English to French: Basically, this was obviously designed to be promotional material fo...|
|          translate English to French: From the title, the tag-line, the plot summary on the DVD etc|
|    translate English to French: This is the first feature film from Australian comedian Mick Molloy|
|        translate English to French: I like Errol Flynn; I like biographies and I like action movies|
|translate English to French: I watched the first 10 minutes of this show I think I'm gonna barf n...|
|translate English to French: Sherman Hemsley was great in the Jeffersons

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays 
- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function

In [0]:
def predict_batch_fn():
    import numpy as np
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from pyspark import TaskContext

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Initializing model on worker {TaskContext.get().partitionId()}, device {device}.")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        flattened = np.squeeze(inputs).tolist()
        inputs = tokenizer(flattened, 
                           padding=True,
                           return_tensors="pt").to(device)
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_length=128)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
        print("predict: {}".format(len(flattened)))
        return string_outputs
    
    return predict

In [0]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=32)
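
The division of labour between the two functions can be sketched without Spark. The toy code below is not how `predict_batch_udf` is implemented; it only mimics the contract described above, under the assumption that the factory runs once per worker process and the returned function runs once per batch. `make_predict_fn`, `run_in_batches`, and the upper-casing "model" are hypothetical stand-ins.

```python
from typing import Callable, List

load_count = 0  # counts how many times the (pretend) model is loaded


def make_predict_fn() -> Callable[[List[str]], List[str]]:
    """Factory: do the expensive setup once, return a cheap per-batch function."""
    global load_count
    load_count += 1
    model = str.upper  # stand-in for an expensive model load

    def predict(batch: List[str]) -> List[str]:
        # One output per input, in the same order.
        return [model(x) for x in batch]

    return predict


def run_in_batches(rows: List[str], batch_size: int) -> List[str]:
    """Slice rows into batches of at most batch_size and call predict on each."""
    predict = make_predict_fn()  # setup happens once, outside the loop
    outputs: List[str] = []
    for start in range(0, len(rows), batch_size):
        outputs.extend(predict(rows[start:start + batch_size]))
    return outputs


rows = [f"row {i}" for i in range(70)]
results = run_in_batches(rows, batch_size=32)  # batches of 32, 32, and 6
print(len(results), load_count)
```

Keeping the model load inside the factory rather than inside `predict` is what makes the first timed pass slower than the later ones: the setup cost is paid once and then reused.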

In [0]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()

15.901144027709961 seconds


In [0]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()

5.436015844345093 seconds


In [0]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()

5.352211952209473 seconds


In [0]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: This really is by ...|Cette vidéo est vraiment le pire que je n'ai ja...|
|translate English to French: Basically, this wa...|En fait, il s'agissait évidemment de matériel p...|
|translate English to French: From the title, th...|partir du titre, de la ligne de balise, du résu...|
|translate English to French: This is the first ...|C'est le premier long métrage du comédien austr...|
|translate English to French: I like Errol Flynn...|Je m'aime Errol Flynn; je m'aime les biographie...|
|translate English to French: I watched the firs...|Alors, j'ai vu les 10 premières minutes de ce s...|
|translate English to French: Sherman Hemsley wa...|Sherman Hems

Let's try English to German:

In [None]:
input_df2 = df.select(preprocess(col("text"), "translate English to German: ").alias("input")).cache()

In [None]:
%%time
# first pass caches model/fn
preds = input_df2.withColumn("preds", generate(struct("input")))
result = preds.collect()

In [None]:
%%time
preds = input_df2.withColumn("preds", generate("input"))
result = preds.collect()

In [None]:
%%time
preds = input_df2.withColumn("preds", generate(col("input")))
result = preds.collect()

In [None]:
preds.show(truncate=50)

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.  

<img src="../images/spark-pytriton.png" alt="drawing" width="700"/>

In [0]:
from functools import partial

In [0]:
def triton_server():
    import signal
    import numpy as np
    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton
    from pyspark import TaskContext

    with Triton() as triton:
        print(f"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.")
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        model = T5ForConditionalGeneration.from_pretrained("t5-small")
        
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"SERVER: Using {DEVICE} device.")
        model = model.to(DEVICE)

        @batch
        def _infer_fn(**inputs):
            # Drop only the trailing length-1 dimension so that a batch of one stays a list
            sentences = np.squeeze(inputs["text"], -1).tolist()
            print(f"SERVER: Received batch of size {len(sentences)}")
            decoded_sentences = [s.decode("utf-8") for s in sentences]
            inputs = tokenizer(decoded_sentences,
                            padding=True,
                            return_tensors="pt").to(DEVICE)
            output_ids = model.generate(input_ids=inputs["input_ids"],
                                        attention_mask=inputs["attention_mask"],
                                        max_length=128)
            outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])
            return {
                "translations": outputs,
            }

        triton.bind(
            model_name="ConditionalGeneration",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="translations", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

def start_triton(url, model_name):
    import socket
    from multiprocessing import Process
    from pytriton.client import ModelClient

    hostname = socket.gethostname()
    process = Process(target=triton_server)
    process.start()

    client = ModelClient(url, model_name)
    ready = False
    while not ready:
        try:
            client.wait_for_server(5)
            ready = True
        except Exception as e:
            print(f"Waiting for server to be ready: {e}")
    
    return [(hostname, process.pid)]

#### Start Triton servers

To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU.  
Stage-level scheduling requires a value for `spark.executor.cores`, and requires that `spark.executor.resource.gpu.amount` <= 1.

In [0]:
def _use_stage_level_scheduling(spark, rdd):

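    # Compare numerically; plain string comparison would misorder e.g. "3.10.0" and "3.4.0"
    major, minor = (int(x) for x in spark.version.split(".")[:2])
    if (major, minor) < (3, 4):
        raise Exception("Stage-level scheduling is not supported in Spark < 3.4.0")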

    # Pass an explicit default so that an unset key reaches the assert instead of raising inside get()
    executor_cores = spark.conf.get("spark.executor.cores", None)
    assert executor_cores is not None, "spark.executor.cores is not set"
    executor_gpus = spark.conf.get("spark.executor.resource.gpu.amount", None)
    assert executor_gpus is not None and int(executor_gpus) <= 1, "spark.executor.resource.gpu.amount must be set and <= 1"

    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    # Each Triton server task requests more than half of the executor's CPU cores,
    # which ensures that at most one such task is scheduled per executor.
    #
    # Note that task_cores must not be set to executor_cores/2 or less: task_gpus
    # alone cannot guarantee that tasks land on different executors, even with task_gpus=1.0.
    #
    # If spark-rapids is enabled, the task requests all executor cores so that no other
    # ETL task runs alongside it, which avoids OOM.
    spark_plugins = spark.conf.get("spark.plugins", " ")
    assert spark_plugins is not None
    spark_rapids_sql_enabled = spark.conf.get("spark.rapids.sql.enabled", "true")
    assert spark_rapids_sql_enabled is not None

    task_cores = (
        int(executor_cores)
        if "com.nvidia.spark.SQLPlugin" in spark_plugins
        and "true" == spark_rapids_sql_enabled.lower()
        else (int(executor_cores) // 2) + 1
    )
    # task_gpus is the number of slots per GPU address that the task requires;
    # it does not mean how many GPUs the task requests, so it can be any value in (0, 0.5] or 1.
    task_gpus = 1.0

    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    rp = ResourceProfileBuilder().require(treqs).build

    print(f"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})")

    return rdd.withResources(rp)
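
The effect of the `task_cores` choice can be checked with simple arithmetic. The sketch below assumes that, as far as CPU is concerned, an executor can run `executor_cores // task_cores` tasks concurrently; the helper names are illustrative only and are not part of any Spark API.

```python
def task_cores_for(executor_cores: int, rapids_sql_enabled: bool) -> int:
    """Mirror the task_cores expression used in _use_stage_level_scheduling."""
    if rapids_sql_enabled:
        return executor_cores
    return executor_cores // 2 + 1


def max_concurrent_tasks(executor_cores: int, task_cores: int) -> int:
    """Tasks that fit on one executor when limited by CPU cores alone."""
    return executor_cores // task_cores


for cores in (1, 2, 8, 16):
    tc = task_cores_for(cores, rapids_sql_enabled=False)
    print(f"executor_cores={cores:2d} task_cores={tc:2d} "
          f"tasks_per_executor={max_concurrent_tasks(cores, tc)}")
```

Since `executor_cores // 2 + 1` is always more than half of `executor_cores`, two such tasks can never fit on the same executor, which is what keeps the Triton servers on separate executors.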

#####  Specify the number of nodes in the cluster.
Following the README, the example standalone cluster uses 1 node.   
The example Databricks/Dataproc cluster scripts use 8 nodes. 

In [None]:
num_nodes = 8  # Change based on cluster setup

In [0]:
url = "localhost"
model_name = "ConditionalGeneration"

sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)

Requesting stage-level resources: (cores=8, gpu=1.0)


In [0]:
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(url, model_name)).collectAsMap()
print("Triton Server PIDs:\n", json.dumps(pids, indent=4))

Triton Server PIDs:
 {
    "1029-000529-b00wdfos-10-2-128-17": 5307,
    "1029-000529-b00wdfos-10-2-128-13": 5332,
    "1029-000529-b00wdfos-10-2-128-20": 5331,
    "1029-000529-b00wdfos-10-2-128-12": 5320,
    "1029-000529-b00wdfos-10-2-128-22": 5334,
    "1029-000529-b00wdfos-10-2-128-18": 5315,
    "1029-000529-b00wdfos-10-2-128-16": 5316,
    "1029-000529-b00wdfos-10-2-128-14": 5335
}


In [0]:
def triton_fn(url, model_name, init_timeout_s):
    import numpy as np
    from pytriton.client import ModelClient

    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            # Flatten to 1-D explicitly; a bare np.squeeze would collapse a single-row batch to a scalar
            flattened = np.reshape(inputs, -1).tolist()
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["translations"], -1)
            return result_data
        
    return infer_batch
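
The client and server agree on a simple wire format: each string travels as UTF-8 bytes in a row of length one, and the server decodes it back to `str` before tokenizing. The following sketch shows that round trip in plain Python, with nested lists standing in for the NumPy arrays used above; `client_encode` and `server_decode` are hypothetical names, not PyTriton functions.

```python
from typing import List


def client_encode(batch: List[str]) -> List[List[bytes]]:
    """Client side: one row per text, one UTF-8 bytes element per row."""
    return [[text.encode("utf-8")] for text in batch]


def server_decode(rows: List[List[bytes]]) -> List[str]:
    """Server side: drop the trailing length-1 dimension and decode back to str."""
    return [row[0].decode("utf-8") for row in rows]


batch = ["translate English to French: The house is wonderful", "Cette vidéo"]
wire = client_encode(batch)
print(len(wire), len(wire[0]))       # rows, elements per row
print(server_decode(wire) == batch)  # the round trip preserves the text
```

Encoding explicitly matters for non-ASCII text such as the accented characters in the French output, where one character can occupy more than one byte.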

#### Load and preprocess DataFrame

In [0]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [0]:
df = spark.read.parquet(data_path).limit(2048).repartition(64)

In [0]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()

In [0]:
generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=600),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=32)

In [0]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()

CPU times: user 13.6 ms, sys: 14.4 ms, total: 28 ms
Wall time: 9.87 s


In [0]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()

CPU times: user 14.4 ms, sys: 7.72 ms, total: 22.2 ms
Wall time: 7.53 s


In [0]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()

CPU times: user 17.1 ms, sys: 5.08 ms, total: 22.2 ms
Wall time: 7.48 s


In [0]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: I've been a fan of...|Je suis un fan de Jim Henson et de ses personna...|
|translate English to French: This is one of the...|C'est l'une des pires mini-séries que je n'ai j...|
|translate English to French: Countenance! Antoi...|Antoine Monot, dans une impersonation de l'albu...|
|translate English to French: I believe that thi...|Je crois que c'est l'une des meilleures représe...|
|translate English to French: I Feel the Niiiiii...|Je suis en train de s'en tenir à la somme de l'...|
|translate English to French: If you don't like ...|Si vous n'aimez pas de mauvais actes, de mauvai...|
|translate English to French: This whole movie i...|Ce film tout

#### Shut down servers

In [0]:
def stop_triton(pids):
    import os
    import socket
    import signal
    import time 
    
    hostname = socket.gethostname()
    pid = pids.get(hostname, None)
    assert pid is not None, f"Could not find pid for {hostname}"
    os.kill(pid, signal.SIGTERM)
    time.sleep(7)
    
    for _ in range(5):
        try:
            os.kill(pid, 0)
        except OSError:
            return [True]
        time.sleep(3)

    return [False]

shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)
shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()

Requesting stage-level resources: (cores=8, gpu=1.0)


[True, True, True, True, True, True, True, True]
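
The liveness check in `stop_triton` relies on the POSIX convention that sending signal `0` performs the usual existence and permission checks without delivering a signal. Below is a small standalone sketch of that idea, using a throwaway child process instead of a Triton server. `pid_exists` and `terminate_and_wait` are hypothetical helper names, and the code assumes a POSIX system.

```python
import os
import signal
import subprocess
import sys
import time


def pid_exists(pid: int) -> bool:
    """Return True if a process with this PID exists (POSIX only)."""
    try:
        os.kill(pid, 0)  # signal 0: existence/permission check, nothing is delivered
    except ProcessLookupError:
        return False
    except PermissionError:
        return True  # the process exists but belongs to another user
    return True


def terminate_and_wait(proc: subprocess.Popen, attempts: int = 20, delay_s: float = 0.1) -> bool:
    """Send SIGTERM, then poll until the child has exited or attempts run out."""
    proc.send_signal(signal.SIGTERM)
    for _ in range(attempts):
        if proc.poll() is not None:  # poll() also reaps the exited child
            return True
        time.sleep(delay_s)
    return False


# Throwaway child that would otherwise sleep for a minute.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
alive_before = pid_exists(child.pid)
stopped = terminate_and_wait(child)
print(alive_before, stopped, pid_exists(child.pid))
```

In this sketch the parent reaps the child through `poll()`. An exited child that has not been reaped still has an entry in the process table, so a signal-0 check on it would continue to succeed.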

In [None]:
spark.stop()