# Databricks: PySpark DL Inference
### Conditional generation with Hugging Face

In this notebook, we demonstrate distributed inference with the T5 transformer, translating the first sentence of each IMDB review from English to French.  
From: https://huggingface.co/docs/transformers/model_doc/t5

In [0]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct
from pyspark.ml.functions import predict_batch_udf
from pyspark import SparkConf

In [0]:
import os
import time
import pandas as pd
import datasets
from datasets import load_dataset
datasets.utils.logging.disable_progress_bar()
datasets.utils.logging.set_verbosity_error()

In [0]:
"""
(Optional): For large datasets, we can point the Hugging Face dataset cache directory at the cluster's local disk (or at DBFS, to persist it after cluster termination) rather than the default '/' (the ephemeral file system at the instance root). The code below uses local disk, which enables autoscaling up to 5TB.

For more info on the tradeoffs, see https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/hugging-face-dataset-download.html.
"""

# LOCAL_DISK_MOUNT = '/local_disk0'
# dbutils.fs.mkdirs(f"file://{LOCAL_DISK_MOUNT}/hf_cache")
# LOCAL_DISK_CACHE_DIR = f'{LOCAL_DISK_MOUNT}/hf_cache/'
# dataset = load_dataset("imdb", split="test", cache_dir=LOCAL_DISK_CACHE_DIR)

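As an alternative to passing `cache_dir` to every `load_dataset` call, the `datasets` library also reads the `HF_DATASETS_CACHE` environment variable. This is a minimal sketch, assuming the variable is set before `datasets` is first imported (the library generally reads it at import time, so in this notebook it would have to run before the import cell). The helper name `configure_hf_cache` is hypothetical, and a temporary directory stands in for `/local_disk0` so the sketch runs anywhere.

```python
import os
import tempfile


def configure_hf_cache(cache_root: str) -> str:
    """Create <cache_root>/hf_cache and point HF_DATASETS_CACHE at it.

    Hypothetical helper: call it before `import datasets`, since the
    library generally reads this variable when it is first imported.
    """
    cache_dir = os.path.join(cache_root, "hf_cache")
    # Plain-Python counterpart of the dbutils.fs.mkdirs call on local disk.
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["HF_DATASETS_CACHE"] = cache_dir
    return cache_dir


# A temporary directory stands in for '/local_disk0'.
example_cache_dir = configure_hf_cache(tempfile.mkdtemp())
print(example_cache_dir)
```

On a Databricks cluster, the argument would be the local disk mount instead of a temporary directory.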

#### Load the IMDB movie reviews dataset.

In [0]:
dataset = load_dataset("imdb", split="test")
dataset = dataset.to_pandas().drop(columns="label")

#### Create PySpark DataFrame

In [0]:
df = spark.createDataFrame(dataset).repartition(64)
df.schema

StructType([StructField('text', StringType(), True)])

In [0]:
df.count()

25000

In [0]:
df.take(1)

[Row(text="Frankly, after Cotton club and Unfaithful, it was kind of embarrassing to watch Lane and Gere in this film, because it is BAD. The acting was bad, the dialogs were extremely shallow and insincere. It was well shot, but, then again, it is a big budget movie. It was too predictable, even for a chick flick. I even knew from the beginning that he was going to die in the end, the only thing I didn't know was how. Too politically correct. Very disappointing. The only thing really worth watching was the scenery and the house, because it is beautiful. But, if you want that, watch National geographic. I love Lane, but I've never seen her in a movie this lousy. As far as Gere goes, he's a good actor, but he had movies like this, so I'm not surprised. An hour and a half I wish I could bring back.")]

In [0]:
data_path = "/FileStore/rishic/datasets/imdb_test"
df.write.mode("overwrite").parquet(data_path)

#### Load and preprocess DataFrame

In [0]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text]) # Take first sentence.
    return _preprocess(text)
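Per row, the pandas UDF prepends the T5 task prefix and keeps only the text before the first period. Isolating that string logic as a plain function (the name `to_t5_input` is hypothetical) makes it easy to check edge cases without Spark. Because the split is on every `.`, an abbreviation such as "Mr." ends the "sentence" early, and a review with no period is kept whole.

```python
PREFIX = "translate English to French: "


def to_t5_input(text: str, prefix: str = "") -> str:
    """Same per-row logic as _preprocess: prefix + text up to the first '.'."""
    return prefix + text.split(".")[0]


examples = [
    "I loved this film. The cast was great.",
    "Mr. Smith gives a fine performance.",  # abbreviation ends the "sentence" early
    "No period in this review at all",      # whole text is kept
]
converted = [to_t5_input(s, PREFIX) for s in examples]
for line in converted:
    print(line)
```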

In [0]:
df = spark.read.parquet(data_path).limit(2046).repartition(64)
df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                                text|
+----------------------------------------------------------------------------------------------------+
|I stumbled across this movie late at night on TV. My brother and I could not stop laughing at how...|
|Was a college acting class exercise filmed and released as a movie? The formulaic posturing and s...|
|This is a better-than-average entry in the Saint series - It holds your interest and, as mysterie...|
|There was a lot about Little Vera that was strange to me. All in all I did enjoy this movie, but ...|
|I never dreamed when I started watching this DVD that I would be totally mesmerized by it within ...|
|I'm a Boorman fan, but this is arguably his least successful film. Comedy has never been his stro...|
|I wasn't sure about getting this movie on DVD because I really do have s

In [0]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()
input_df.show(truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               input|
+----------------------------------------------------------------------------------------------------+
|translate English to French: I loved the first "American Graffiti" with all my heart and soul tha...|
|translate English to French: Firstly this has nothing to do with the much better 18 weapons of Ku...|
|translate English to French: But perhaps you have to have grown up in the 80's to truly appreciat...|
|translate English to French: Take a subject I didn't know much about and make it exciting, why do...|
|      translate English to French: This is the worst film I have ever seen, so bad it is astonishing|
|translate English to French: Margaret Mitchell spins in her grave every time somebody watches thi...|
|                                     translate English to French: I have

## Inference using Spark DL API

Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf).
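`predict_batch_udf` takes a zero-argument function (below, `predict_batch_fn`) that performs one-time setup, such as loading the model, and returns the function applied to each batch of at most `batch_size` rows. The plain-Python sketch below models that contract for a single partition. It is a simplification for illustration, not Spark's implementation: the helper `run_in_batches` is hypothetical, and the per-worker caching that makes later runs faster is left out.

```python
from typing import Callable, List, Sequence

PredictFn = Callable[[Sequence[str]], List[str]]


def run_in_batches(make_predict_fn: Callable[[], PredictFn],
                   rows: Sequence[str],
                   batch_size: int) -> List[str]:
    """Call the factory once, then apply the returned predict fn batch by batch."""
    predict = make_predict_fn()  # one-time setup (model loading in the notebook)
    outputs: List[str] = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        preds = predict(batch)
        if len(preds) != len(batch):  # one output row is required per input row
            raise ValueError("predict must return one result per input row")
        outputs.extend(preds)
    return outputs


# Toy factory that records how often setup runs and how large each batch is.
setup_calls = 0
batch_sizes: List[int] = []


def make_upper() -> PredictFn:
    global setup_calls
    setup_calls += 1

    def predict(batch: Sequence[str]) -> List[str]:
        batch_sizes.append(len(batch))
        return [s.upper() for s in batch]

    return predict


result = run_in_batches(make_upper, ["a", "b", "c", "d", "e", "f", "g"], batch_size=3)
print(result, setup_calls, batch_sizes)
```

The key design point is that expensive setup happens in the factory, not in `predict`, so it is not repeated for every batch.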

In [0]:
def predict_batch_fn():
    import numpy as np
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    from pyspark import TaskContext
    torch.cuda.empty_cache()

    print(f"Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device.")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        # Flatten to a 1-D list of strings. Inputs may arrive as (batch,) or (batch, 1),
        # and np.squeeze would collapse a batch of one into a bare string.
        flattened = np.asarray(inputs).reshape(-1).tolist()
        encoded = tokenizer(flattened,
                            padding=True,
                            return_tensors="pt").to(device)
        outputs = model.generate(input_ids=encoded["input_ids"],
                                 attention_mask=encoded["attention_mask"],
                                 max_length=128)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
        print("predict: {}".format(len(flattened)))
        return string_outputs
    
    return predict

In [0]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=128)
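Note that `batch_size=128` is an upper bound per call, and a batch never spans partitions. With 2,046 rows spread over 64 partitions, each partition holds about 32 rows, so each call to `predict` should see roughly 32 inputs rather than 128 (the `print` in `predict` reports the actual size). The arithmetic below assumes an ideal round-robin split, which `repartition` aims for; the helper names are hypothetical.

```python
import math
from typing import List


def even_partition_sizes(num_rows: int, num_partitions: int) -> List[int]:
    """Row counts per partition for an ideal round-robin split."""
    base, extra = divmod(num_rows, num_partitions)
    return [base + (1 if i < extra else 0) for i in range(num_partitions)]


def calls_per_partition(partition_rows: int, batch_size: int) -> int:
    """Number of predict calls needed for one partition."""
    return math.ceil(partition_rows / batch_size)


sizes = even_partition_sizes(2046, 64)
calls = [calls_per_partition(n, 128) for n in sizes]
print(min(sizes), max(sizes), sum(calls))
```

Under this assumption each of the 64 tasks makes a single call of about 32 rows; raising `batch_size` would change nothing here, while fewer partitions would produce larger batches.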

In [0]:
start_time = time.time()
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()
print(f"{time.time() - start_time} seconds")

26.08012843132019 seconds


In [0]:
start_time = time.time()
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()
print(f"{time.time() - start_time} seconds")

5.499233961105347 seconds


In [0]:
start_time = time.time()
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()
print(f"{time.time() - start_time} seconds")

5.471662521362305 seconds
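The three timing cells repeat the same `start_time` bookkeeping. A small context manager keeps it in one place. This is a hypothetical helper (`timer` is not part of the notebook), and it uses `time.perf_counter`, which is better suited than `time.time` for measuring elapsed intervals.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timer(label: str, timings: Dict[str, float]) -> Iterator[None]:
    """Record the wall-clock duration of the with-block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = time.perf_counter() - start
        print(f"{label}: {timings[label]:.2f} seconds")


timings: Dict[str, float] = {}
with timer("demo", timings):
    time.sleep(0.05)  # stands in for `results = preds.collect()`
```

In the notebook this would read `with timer("first pass", timings): results = preds.collect()`, keeping the slower first run (which includes loading the model) alongside the later ones for comparison.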


In [0]:
preds.show(truncate=100)

+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|                                                                                               input|                                                                                               preds|
+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|translate English to French: I loved the first "American Graffiti" with all my heart and soul tha...|J'ai aimé le premier "American Graffiti" avec tout mon cur et tout mon âme que j'ai considéré com...|
|translate English to French: Firstly this has nothing to do with the much better 18 weapons of Ku...|Premièrement, cela n'a rien à voir avec les 18 armes bien meilleurs de Kung Fu en 

## Using Triton Inference Server

The Triton Inference Server is launched in a separate process on each node.   
We use [PyTriton](https://github.com/triton-inference-server/pytriton), which provides a Python API to handle client/server communication.

In [0]:
from functools import partial

In [0]:
def triton_server():
    import signal
    import numpy as np
    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton
    from pyspark import TaskContext

    with Triton() as triton:
        print(f"Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.")
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        model = T5ForConditionalGeneration.from_pretrained("t5-small")
        
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {DEVICE} device.")
        model = model.to(DEVICE)

        @batch
        def _infer_fn(**inputs):
            # Flatten (batch, 1) -> (batch,). np.squeeze would collapse a batch of one
            # into a scalar, and iterating over a single bytes object yields ints.
            sentences = np.asarray(inputs["text"]).reshape(-1).tolist()
            print(f"Received batch of size {len(sentences)}")
            decoded_sentences = [s.decode("utf-8") for s in sentences]
            encoded = tokenizer(decoded_sentences,
                                padding=True,
                                return_tensors="pt").to(DEVICE)
            output_ids = model.generate(input_ids=encoded["input_ids"],
                                        attention_mask=encoded["attention_mask"],
                                        max_length=128)
            outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])
            return {
                "translations": outputs,
            }

        triton.bind(
            model_name="ConditionalGeneration",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="translations", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=256,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def stop_triton(signum, frame):
            print("Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("Serving inference")
        triton.serve()

def start_triton(idx, url, model_name):
    from multiprocessing import Process
    from pytriton.client import ModelClient

    process = Process(target=triton_server)
    process.start()

    # Use the client as a context manager so its connection is closed once the server is up.
    with ModelClient(url, model_name) as client:
        ready = False
        while not ready:
            try:
                client.wait_for_server(5)
                ready = True
            except Exception as e:
                print(f"Waiting for server to be ready: {e}")
    
    return [(idx, process.pid)]

# Start servers
num_nodes = 8  # set to the number of worker nodes; one server is launched per barrier task
url = "localhost"
model_name = "ConditionalGeneration"

sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)
pids = nodeRDD.barrier().mapPartitionsWithIndex(lambda idx, _: start_triton(idx, url, model_name)).collectAsMap()
print("Triton Server PIDs:\n", pids)

Triton Server PIDs:
 {0: 3151, 1: 3155, 2: 3149, 3: 3133, 4: 3134, 5: 3131, 6: 3129, 7: 3140}
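The readiness loop in `start_triton` retries forever, so a server that fails to start would hang the barrier stage. A bounded variant is sketched below, assuming a deadline is acceptable; `wait_until_ready` and its parameters are hypothetical. In `start_triton`, `probe` could be a small function that calls `client.wait_for_server(5)` and returns `True`, ideally after first checking `process.is_alive()` so that a crashed server process fails fast instead of waiting out the deadline.

```python
import time
from typing import Callable


def wait_until_ready(probe: Callable[[], bool],
                     timeout_s: float = 600.0,
                     interval_s: float = 5.0,
                     sleep: Callable[[float], None] = time.sleep) -> int:
    """Poll `probe` until it returns True; raise TimeoutError after `timeout_s`.

    Exceptions raised by the probe count as "not ready yet".
    Returns the number of attempts that were needed.
    """
    deadline = time.monotonic() + timeout_s
    attempts = 0
    while True:
        attempts += 1
        try:
            if probe():
                return attempts
        except Exception as exc:  # e.g. the server is not listening yet
            print(f"Waiting for server to be ready: {exc}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Server not ready after {attempts} attempts")
        sleep(interval_s)


# Demo: a probe that fails twice before succeeding (sleeping is skipped).
state = {"calls": 0}


def flaky_probe() -> bool:
    state["calls"] += 1
    if state["calls"] < 3:
        raise ConnectionError("not listening yet")
    return True


attempts_needed = wait_until_ready(flaky_probe, timeout_s=60, sleep=lambda s: None)
print(attempts_needed)
```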


In [0]:
def triton_fn(url, model_name, init_timeout_s):
    import numpy as np
    from pytriton.client import ModelClient

    print(f"Connecting to model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            # Flatten to a 1-D list; np.squeeze would turn a batch of one into a bare string.
            flattened = np.asarray(inputs).reshape(-1).tolist()
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["translations"], -1)
            return result_data
        
    return infer_batch
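Strings travel through Triton as BYTES tensors, so `triton_fn` sends each input as one UTF-8 `bytes` element per row, a batch of shape (batch, 1), and the server decodes it back before tokenizing. The round trip is sketched below with plain lists instead of NumPy arrays, an assumption made to keep the example dependency-free; `encode_batch` and `decode_batch` are hypothetical names.

```python
from typing import List


def encode_batch(texts: List[str]) -> List[List[bytes]]:
    """One row per sample, one UTF-8 element per row: shape (batch, 1)."""
    return [[text.encode("utf-8")] for text in texts]


def decode_batch(rows: List[List[bytes]]) -> List[str]:
    """Inverse of encode_batch: drop the length-1 axis and decode each element."""
    return [row[0].decode("utf-8") for row in rows]


batch = ["translate English to French: I loved it", "J'ai aimé le film"]
wire = encode_batch(batch)
print(wire)
round_trip = decode_batch(wire)
```

Non-ASCII characters such as "é" take more than one byte on the wire, which is why both sides agree on UTF-8 explicitly.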

In [0]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [0]:
df = spark.read.parquet(data_path).limit(2046).repartition(64)

In [0]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()

In [0]:
generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=600),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=128)
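On the client side, each Spark task sends up to `batch_size=128` rows per request. On the server side, the `DynamicBatcher` configured in `triton_server` can merge requests that arrive within `max_queue_delay_microseconds` (5 ms) into one model call of at most `max_batch_size=256` rows. The sketch below is a simplified model of that grouping, assuming requests are merged in arrival order until the next one would exceed the size cap or arrives after the delay window. It is not Triton's actual scheduler, and `group_requests` is a hypothetical name.

```python
from typing import List, Tuple


def group_requests(arrivals: List[Tuple[float, int]],
                   max_batch_size: int = 256,
                   max_queue_delay_s: float = 0.005) -> List[List[int]]:
    """Group (arrival_time_s, num_rows) requests, sorted by time, into server batches.

    Returns a list of batches, each a list of request indices.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    current_rows = 0
    window_start = 0.0
    for i, (t, rows) in enumerate(arrivals):
        if rows > max_batch_size:
            raise ValueError(f"request {i} exceeds max_batch_size")
        # Close the open batch if this request arrives too late or would overflow it.
        if current and (t - window_start > max_queue_delay_s
                        or current_rows + rows > max_batch_size):
            batches.append(current)
            current, current_rows = [], 0
        if not current:
            window_start = t  # the delay window starts with the oldest queued request
        current.append(i)
        current_rows += rows
    if current:
        batches.append(current)
    return batches


example = group_requests([(0.000, 128), (0.001, 128), (0.002, 32), (0.010, 64)])
print(example)
```

With the roughly 32-row requests produced by this notebook's 64 partitions, several requests that arrive close together could share one model call under this model.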

In [0]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()

CPU times: user 12.4 ms, sys: 14.3 ms, total: 26.7 ms
Wall time: 9.77 s


In [0]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()

CPU times: user 6.7 ms, sys: 13.3 ms, total: 20 ms
Wall time: 6.47 s


In [0]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()

CPU times: user 17.7 ms, sys: 2.28 ms, total: 19.9 ms
Wall time: 6.65 s


In [0]:
preds.show(truncate=100)

+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|                                                                                               input|                                                                                               preds|
+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------+
|translate English to French: This really is by far the worst movie I've ever seen in my whole lif...|         Cette vidéo est vraiment le pire que je n'ai jamais vu dans ma vie (je me rapproche de 47)!|
|                               translate English to French: I thought i could see something good but|                                          Je pensais que je pouvais voir quelque c

In [0]:
def stop_triton(idx, pids):
    import os
    import signal
    import time 

    pid = pids[idx]

    num_retries = 5
    for _ in range(num_retries):
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            return [True]
        time.sleep(5)

    return [False]

nodeRDD.barrier().mapPartitionsWithIndex(lambda idx, _: stop_triton(idx, pids)).collect()

[True, True, True, True, True, True, True, True]