# PySpark DL Inference on Databricks
### Conditional generation with Huggingface

In this notebook, we demonstrate distributed inference with the T5 transformer to translate sentences from English to French.  
From: https://huggingface.co/docs/transformers/model_doc/t5

In [0]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct
from pyspark.ml.functions import predict_batch_udf
from pyspark import SparkConf

In [0]:
import os
import time
import pandas as pd
import datasets
from datasets import load_dataset
datasets.utils.logging.disable_progress_bar()
datasets.utils.logging.set_verbosity_error()

In [0]:
"""
(Optional): For large datasets, we can specify the Huggingface dataset cache directory to the cluster's local disk (or DBFS to persist after cluster termination), rather than the default '/' (ephemeral file system at instance root). The code below specifies local disk, which enables autoscaling up to 5TB.

For more info on the tradeoffs, see https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/hugging-face-dataset-download.html.
"""

# LOCAL_DISK_MOUNT = '/local_disk0'
# dbutils.fs.mkdirs(f"file://{LOCAL_DISK_MOUNT}/hf_cache")
# LOCAL_DISK_CACHE_DIR = f'{LOCAL_DISK_MOUNT}/hf_cache/'
# dataset = load_dataset("imdb", split="test", cache_dir=LOCAL_DISK_CACHE_DIR)


Load the IMDB Movie Reviews dataset from Hugging Face.

In [0]:
dataset = load_dataset("imdb", split="test")
dataset = dataset.to_pandas().drop(columns="label")

#### Create PySpark DataFrame

In [0]:
df = spark.createDataFrame(dataset).repartition(64)
df.schema

StructType([StructField('text', StringType(), True)])

In [0]:
df.count()

25000

In [0]:
df.take(1)

[Row(text="Stranded in Space (1972) MST3K version - a very not good TV movie pilot, for a never to be made series, in which an astronaut finds himself trapped on Earth's evil twin. Having a planet of identical size and mass orbiting in the same plane as the earth, but on the opposite side of the sun, is a well worn SF chestnut - the idea is over 2,000 years old, having been invented by the Ancient Greeks. In this version the Counter World is run as an Orwellian 'perfect' society. Where, for totally inexplicable reasons, everyone speaks English and drives late model American cars. After escaping from his prisonlike hospital, the disruptive Earthian is chased around Not Southern California by TV and bad movie stalwart Cameron Mitchell who, like his minions, wears double breasted suits and black polo neck jumpers - a stylishly evil combination which I fully intend to adopt if ever I become a totalitarian overlord. Our hero escapes several times before ending up gazing at the alien world's

In [0]:
data_path = "/FileStore/rishic/datasets/imdb_test"
df.write.mode("overwrite").parquet(data_path)

#### Load and preprocess DataFrame

We'll take the first sentence from each sample as our target for translation.

In [0]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)
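
The splitting rule inside `_preprocess` can be checked without Spark. The sketch below applies the same expression to plain strings; `first_sentence` and the sample strings are hypothetical and exist only for illustration. It also shows a limitation of the rule: splitting on the first period truncates initials and abbreviations.

```python
def first_sentence(text: str, prefix: str = "") -> str:
    """Return the prefix plus everything before the first period."""
    return prefix + text.split(".")[0]

prefix = "translate English to French: "
samples = [
    "I found this film extremely disturbing. It stayed with me.",
    "Co-scripted by William H. Macy, this one surprised me.",  # initial is cut off
    "No period here",                                          # returned unchanged
]
for s in samples:
    print(first_sentence(s, prefix))
```

A sentence tokenizer would handle initials better, but the simple split keeps the preprocessing free of extra dependencies.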

In [0]:
df = spark.read.parquet(data_path).limit(2048).repartition(64)
df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                                text|
+----------------------------------------------------------------------------------------------------+
|For all its many flaws, I'm inclined to be charitable towards "Thing". There is the nugget of an ...|
|Nothing will ever top KOMODO with the lovely Jill Hennessey as a shrink (!), but KvC ain't quite ...|
|I think this movie is my favorite movie. I am not sure why, but it is. Julia Duffy has been my fa...|
|I've been a devoted IMDB visitor for a few years. This is the movie that finally compelled me to ...|
|This movie is all about entertainment. Imagine your friends that you love spending time with, the...|
|Fatal Contact: Bird Flu in America: 3 out of 10: This movie is both funny and sad. The funny part...|
|My brother plays "Moose" in this film. Although most of his scenes were 

In [0]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()
input_df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|translate English to French: This joins the endless line of corny, predictable 50's sci-fi shlock...|
|                                      translate English to French: I gave this film my rare 10 stars|
|                                               translate English to French: Co-scripted by William H|
|                          translate English to French: Well don't expect anything deep an meaningful|
|translate English to French: From the perspective of the hectic, contemporary world in which we l...|
|                                 translate English to French: I found this film extremely disturbing|
|                                     translate English to French: I have

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):

- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays 
- predict_batch_udf will automatically convert the Spark DataFrame columns into numpy input batches
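
The contract described in the two points above can be sketched without Spark or PyTorch: a loader function runs once and returns a `predict` closure that maps a NumPy batch to a NumPy array with one output per row. In this sketch `make_predict_fn` and the upper-casing "model" are hypothetical stand-ins, and the input is assumed to arrive as a string array of shape `(batch,)` or `(batch, 1)`. The sketch flattens with `reshape(-1)`, since `np.squeeze` would collapse a batch of one row into a 0-d array.

```python
import numpy as np

def make_predict_fn():
    # Stand-in for loading a real model once per worker (assumption).
    model = str.upper

    def predict(inputs: np.ndarray) -> np.ndarray:
        # Flatten (batch,) or (batch, 1) input into a list of strings.
        sentences = np.asarray(inputs).reshape(-1).tolist()
        return np.array([model(s) for s in sentences])

    return predict

predict = make_predict_fn()
batch = np.array([["bonjour"], ["merci"]])  # shape (2, 1)
single = np.array([["salut"]])              # shape (1, 1)

print(predict(batch))
print(predict(single))
# np.squeeze on a one-row batch yields a single str, not a list:
print(type(np.squeeze(single).tolist()))
```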

In [0]:
def predict_batch_fn():
    import numpy as np
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from pyspark import TaskContext
    torch.cuda.empty_cache()

    print(f"Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device.")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        flattened = np.squeeze(inputs).tolist()
        inputs = tokenizer(flattened, 
                           padding=True,
                           return_tensors="pt").to(device)
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_length=128)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
        print("predict: {}".format(len(flattened)))
        return string_outputs
    
    return predict

In [0]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=128)
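
How `batch_size` interacts with partitioning can be estimated with simple arithmetic. The sketch below assumes that rows are spread evenly over partitions and that batches are formed within a partition, which is only approximately true in practice; `batches_per_partition` is a hypothetical helper.

```python
import math

def batches_per_partition(num_rows: int, num_partitions: int, batch_size: int):
    """Estimate rows per partition, batches per partition, and the largest batch."""
    rows_per_partition = math.ceil(num_rows / num_partitions)
    num_batches = math.ceil(rows_per_partition / batch_size)
    largest_batch = min(rows_per_partition, batch_size)
    return rows_per_partition, num_batches, largest_batch

# Values used in this notebook: 2048 rows, 64 partitions, batch_size=128.
print(batches_per_partition(2048, 64, 128))
```

Under these assumptions each partition holds about 32 rows, so the effective batch size in this notebook is likely closer to 32 than to 128.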

In [0]:
start_time = time.time()
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()
print(f"{time.time() - start_time} seconds")

25.126676082611084 seconds


In [0]:
start_time = time.time()
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()
print(f"{time.time() - start_time} seconds")

4.936200141906738 seconds


In [0]:
start_time = time.time()
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()
print(f"{time.time() - start_time} seconds")

4.917703151702881 seconds
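
The drop from about 25 seconds to about 5 seconds is consistent with the note that the first pass caches the model and function. The pattern can be sketched in plain Python, assuming a loader that is memoized per process; `load_model`, `LOAD_CALLS`, and the sleep that stands in for model loading are all hypothetical.

```python
import time
from functools import lru_cache

LOAD_CALLS = 0

@lru_cache(maxsize=None)
def load_model(name: str):
    """Pretend to load a model; only the first call per process pays the cost."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    time.sleep(0.2)  # stand-in for downloading and initializing weights
    return lambda text: text[::-1]  # dummy "model" that reverses a string

def timed_predict(text: str):
    """Run one prediction and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = load_model("dummy")(text)
    return result, time.perf_counter() - start

_, cold = timed_predict("abc")  # pays the loading cost
_, warm = timed_predict("abc")  # reuses the cached model
print(f"cold: {cold:.3f}s, warm: {warm:.3f}s, loads: {LOAD_CALLS}")
```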


In [0]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: This joins the end...|Cette histoire s'ajoute à la ligne infinie de s...|
|translate English to French: I gave this film m...|           J'ai donné ce film mes 10 étoiles rares|
|translate English to French: Co-scripted by Wil...|                           Coscripté par William H|
|translate English to French: Well don't expect ...|       Ne soyez pas à l'aise avec rien d'important|
|translate English to French: From the perspecti...|Du point de vue du monde hebdomadaire et contem...|
|translate English to French: I found this film ...|          Je trouve ce film extrêmement inquiétant|
|translate English to French: I have seen this m...|            

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

In [0]:
from functools import partial

In [0]:
def triton_server():
    import signal
    import numpy as np
    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton
    from pyspark import TaskContext

    with Triton() as triton:
        print(f"TRITON: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.")
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        model = T5ForConditionalGeneration.from_pretrained("t5-small")
        
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"TRITON: Using {DEVICE} device.")
        model = model.to(DEVICE)

        @batch
        def _infer_fn(**inputs):
            sentences = np.squeeze(inputs["text"]).tolist()
            print(f"TRITON: Received batch of size {len(sentences)}")
            decoded_sentences = [s.decode("utf-8") for s in sentences]
            inputs = tokenizer(decoded_sentences,
                            padding=True,
                            return_tensors="pt").to(DEVICE)
            output_ids = model.generate(input_ids=inputs["input_ids"],
                                        attention_mask=inputs["attention_mask"],
                                        max_length=128)
            outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])
            return {
                "translations": outputs,
            }

        triton.bind(
            model_name="ConditionalGeneration",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="translations", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=256,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            print("TRITON: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("TRITON: Serving inference")
        triton.serve()

def start_triton(idx, url, model_name):
    from multiprocessing import Process
    from pytriton.client import ModelClient

    process = Process(target=triton_server)
    process.start()

    client = ModelClient(url, model_name)
    ready = False
    while not ready:
        try:
            client.wait_for_server(5)
            ready = True
        except Exception as e:
            print(f"Waiting for server to be ready: {e}")
    
    return [(idx, process.pid)]

To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU.  
Stage-level scheduling requires `spark.executor.cores` to be set, and requires that `spark.executor.resource.gpu.amount` = 1.

In [0]:
def _use_stage_level_scheduling(spark, rdd):

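    # Compare (major, minor) numerically; plain string comparison would misorder "3.10" and "3.4".
    if tuple(int(v) for v in spark.version.split(".")[:2]) < (3, 4):
        raise Exception("Stage-level scheduling is not supported in Spark < 3.4.0")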

    executor_cores = spark.conf.get("spark.executor.cores")
    assert executor_cores is not None, "spark.executor.cores is not set"
    executor_gpus = spark.conf.get("spark.executor.resource.gpu.amount")
    assert executor_gpus is not None and int(executor_gpus) <= 1, "spark.executor.resource.gpu.amount must be set and <= 1"

    from pyspark.resource.profile import ResourceProfileBuilder
    from pyspark.resource.requests import TaskResourceRequests

    # Each task requests more than half of the executor's CPU cores, which
    # ensures that each task is sent to a different executor.
    #
    # Note that task_cores must stay above half of the executor cores:
    # requesting task_gpus=1.0 alone does not guarantee that tasks are sent
    # to different executors.
    #
    # If spark-rapids is enabled, each task requests all executor cores so that
    # no other ETL task runs alongside it, which avoids OOM.
    spark_plugins = spark.conf.get("spark.plugins", " ")
    assert spark_plugins is not None
    spark_rapids_sql_enabled = spark.conf.get("spark.rapids.sql.enabled", "true")
    assert spark_rapids_sql_enabled is not None

    task_cores = (
        int(executor_cores)
        if "com.nvidia.spark.SQLPlugin" in spark_plugins
        and "true" == spark_rapids_sql_enabled.lower()
        else (int(executor_cores) // 2) + 1
    )
    # task_gpus is the number of slots per GPU address that the task requires,
    # rather than a count of GPUs, so it can be any value in (0, 0.5] or 1.
    task_gpus = 1.0

    treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
    rp = ResourceProfileBuilder().require(treqs).build

    print(f"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})")

    return rdd.withResources(rp)
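
The choice of `task_cores` can be checked with integer arithmetic. Assuming the CPU dimension limits an executor to `executor_cores // task_cores` concurrent tasks, requesting more than half of the cores leaves room for only one task. Both helpers below are hypothetical and only mirror the expression used in `_use_stage_level_scheduling`.

```python
def task_cores_for(executor_cores: int, rapids_enabled: bool = False) -> int:
    """Mirror the task_cores expression from _use_stage_level_scheduling."""
    return executor_cores if rapids_enabled else (executor_cores // 2) + 1

def max_tasks_per_executor(executor_cores: int, task_cores: int) -> int:
    """CPU-limited number of tasks that fit on one executor at once (assumption)."""
    return executor_cores // task_cores

for cores in (1, 2, 8, 15):
    tc = task_cores_for(cores)
    print(f"executor_cores={cores} task_cores={tc} "
          f"tasks_per_executor={max_tasks_per_executor(cores, tc)}")

# Requesting exactly half of the cores would still allow two tasks per executor.
print(max_tasks_per_executor(8, 4))
```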

In [0]:
# Start servers (8 node cluster)
num_nodes = 8
url = "localhost"
model_name = "ConditionalGeneration"

sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)

Training tasks require the resource(cores=8, gpu=1.0)


In [0]:
pids = nodeRDD.barrier().mapPartitionsWithIndex(lambda idx, _: start_triton(idx, url, model_name)).collectAsMap()
print("Triton Server PIDs:\n", pids)

Triton Server PIDs:
 {0: 2842, 1: 2807, 2: 2813, 3: 2783, 4: 2834, 5: 2783, 6: 2799, 7: 2802}


In [0]:
def triton_fn(url, model_name, init_timeout_s):
    import numpy as np
    from pytriton.client import ModelClient

    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            flattened = np.squeeze(inputs).tolist() 
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["translations"], -1)
            return result_data
        
    return infer_batch
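
The encoding step in `infer_batch` and the decoding step in `_infer_fn` form a round trip between Python strings and a NumPy bytes array of shape `(batch, 1)`. A minimal sketch of that round trip using NumPy only, with no Triton involved; `encode_batch` and `decode_batch` are hypothetical helper names.

```python
import numpy as np

def encode_batch(texts):
    """Client side: UTF-8 encode each string into a (batch, 1) bytes array."""
    encoded = [[t.encode("utf-8")] for t in texts]
    return np.array(encoded, dtype=np.bytes_)

def decode_batch(batch: np.ndarray):
    """Server side: flatten the array and decode each element back to str."""
    return [b.decode("utf-8") for b in batch.reshape(-1).tolist()]

texts = ["Je trouve ce film extrêmement inquiétant", "Coscripté par William H"]
arr = encode_batch(texts)
print(arr.shape, arr.dtype.kind)
print(decode_batch(arr))
```

Encoding explicitly as UTF-8 keeps accented characters intact across the client/server boundary.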

In [0]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [0]:
df = spark.read.parquet(data_path).limit(2048).repartition(64)

In [0]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()

In [0]:
generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=600),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=128)

In [0]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()

CPU times: user 24 ms, sys: 3.51 ms, total: 27.5 ms
Wall time: 10.3 s


In [0]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()

CPU times: user 20.6 ms, sys: 584 µs, total: 21.2 ms
Wall time: 6.91 s


In [0]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()

CPU times: user 17.2 ms, sys: 1.88 ms, total: 19.1 ms
Wall time: 7.61 s


In [0]:
preds.show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                             input|                                             preds|
+--------------------------------------------------+--------------------------------------------------+
|translate English to French: This joins the end...|Cette histoire s'ajoute à la ligne infinie de s...|
|translate English to French: I gave this film m...|           J'ai donné ce film mes 10 étoiles rares|
|translate English to French: Co-scripted by Wil...|                           Coscripté par William H|
|translate English to French: Well don't expect ...|       Ne soyez pas à l'aise avec rien d'important|
|translate English to French: From the perspecti...|Du point de vue du monde hebdomadaire et contem...|
|translate English to French: I found this film ...|          Je trouve ce film extrêmement inquiétant|
|translate English to French: I have seen this m...|            

In [0]:
def stop_triton(idx, pids):
    import os
    import signal
    import time 

    pid = pids[idx]

    num_retries = 5
    for _ in range(num_retries):
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            return [True]
        time.sleep(5)

    return [False]

nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)
nodeRDD.mapPartitionsWithIndex(lambda idx, _: stop_triton(idx, pids)).collect()

Training tasks require the resource(cores=8, gpu=1.0)


[True, True, True, True, True, True, True, True]
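
The retry loop in `stop_triton` relies on `os.kill` raising `ProcessLookupError` once the target process is gone. The same SIGTERM handshake can be sketched on a single machine, assuming a POSIX system. The child script is a hypothetical stand-in for the Triton server process, and `stop_process` is a local analogue of `stop_triton`.

```python
import os
import signal
import subprocess
import sys
import time

# Child process: installs a SIGTERM handler, reports readiness, then idles.
CHILD_CODE = """
import signal, sys, time
def handle_term(signum, frame):
    sys.exit(0)  # stand-in for triton.stop()
signal.signal(signal.SIGTERM, handle_term)
print("ready", flush=True)
while True:
    time.sleep(0.1)
"""

def stop_process(pid: int, num_retries: int = 5, delay_s: float = 0.2) -> bool:
    """Send SIGTERM until the PID no longer exists; return True on success."""
    for _ in range(num_retries):
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            return True
        time.sleep(delay_s)
    return False

proc = subprocess.Popen([sys.executable, "-c", CHILD_CODE],
                        stdout=subprocess.PIPE, text=True)
ready_line = proc.stdout.readline().strip()  # wait until the handler is installed
os.kill(proc.pid, signal.SIGTERM)            # ask the child to shut down
exit_code = proc.wait(timeout=10)            # reap the child so its PID is released
proc.stdout.close()
stopped = stop_process(proc.pid)             # the PID is gone, so this returns True
print(ready_line, exit_code, stopped)
```

In this sketch the parent must reap its own child before the PID disappears; that detail is specific to the local example.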