## PySpark

In [1]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct
from pyspark.ml.functions import predict_batch_udf
from pyspark import SparkConf

In [2]:
import os
import pandas as pd
from datasets import load_dataset

In [3]:
def get_rapids_jar():
    import os
    import requests

    SPARK_RAPIDS_VERSION = "24.10.0"
    SCALA_VERSION = "2.12"
    rapids_jar = f"rapids-4-spark_{SCALA_VERSION}-{SPARK_RAPIDS_VERSION}.jar"
    if not os.path.exists(rapids_jar):
        print("Downloading spark rapids jar")
        url = f"https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_{SCALA_VERSION}/{SPARK_RAPIDS_VERSION}/{rapids_jar}"
        response = requests.get(url)
        if response.status_code == 200:
            with open(rapids_jar, "wb") as f:
                f.write(response.content)
            print(f"File '{rapids_jar}' downloaded and saved successfully.")
        else:
            print(f"Failed to download the file. Status code: {response.status_code}")
    else:
        print("File already exists. Skipping download.")
    return rapids_jar

def initialize_spark(rapids_jar: str):
    '''
    If no active Spark session is found, initialize and configure a new one. 
    '''
    import socket
    hostname = socket.gethostname()
    conda_env = os.environ.get('CONDA_PREFIX')

    conf = SparkConf()
    conf.setMaster(f"spark://{hostname}:7077") # Assuming master is on host and default port. 
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.executor.memory", "8g")
    conf.set("spark.rpc.message.maxSize", "1024")
    conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
    conf.set("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled", "false")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")
    conf.set("spark.rapids.ml.uvm.enabled", "true")
    conf.set("spark.task.resource.gpu.amount", "1")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.jars", rapids_jar)
    conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
    conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
    conf.set("spark.executorEnv.PYTHONPATH", rapids_jar)
    conf.set("spark.executorEnv.LD_LIBRARY_PATH", f"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH")
    conf.set("spark.rapids.memory.gpu.minAllocFraction", "0.0001")
    conf.set("spark.plugins", "com.nvidia.spark.SQLPlugin")
    conf.set("spark.locality.wait", "0s")
    conf.set("spark.sql.cache.serializer", "com.nvidia.spark.ParquetCachedBatchSerializer")
    conf.set("spark.rapids.memory.gpu.pooling.enabled", "false")
    conf.set("spark.sql.execution.sortBeforeRepartition", "false")
    conf.set("spark.rapids.sql.format.parquet.reader.type", "MULTITHREADED")
    conf.set("spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel", "20")
    conf.set("spark.rapids.sql.multiThreadedRead.numThreads", "20")
    conf.set("spark.rapids.sql.python.gpu.enabled", "true")
    conf.set("spark.rapids.memory.pinnedPool.size", "2G")
    conf.set("spark.python.daemon.module", "rapids.daemon")
    conf.set("spark.rapids.sql.batchSizeBytes", "512m")
    conf.set("spark.sql.adaptive.enabled", "false")
    conf.set("spark.sql.files.maxPartitionBytes", "512m")
    conf.set("spark.rapids.sql.concurrentGpuTasks", "4")
    conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
    conf.set("spark.rapids.sql.explain", "NONE")
    
    spark = SparkSession.builder.appName("spark-dl-inference").config(conf=conf).getOrCreate()
    return spark

# Check if Spark session is already active, if not, initialize it
if 'spark' not in globals():
    print("No active Spark session found, initializing manually.")
    rapids_jar = os.environ.get('RAPIDS_JAR')
    if rapids_jar is None:
        rapids_jar = get_rapids_jar()
    spark = initialize_spark(rapids_jar)
else:
    print("Using existing Spark session.")

No active Spark session found, initializing manually.
File already exists. Skipping download.


24/10/28 21:00:53 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
24/10/28 21:00:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/10/28 21:00:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/28 21:00:54 WARN RapidsPluginUtils: RAPIDS Accelerator 24.10.0 using cudf 24.10.0, private revision bd4e99e18e20234ee0c54f95f4b0bfce18a6255e
24/10/28 21:00:54 WARN RapidsPluginUtils: RAPIDS Accelerator is enabled, to disable GPU support set `spark.rapids.sql.enabled` to false.


In [4]:
dataset = load_dataset("imdb", split="test").to_pandas().drop(columns="label")

#### Create PySpark DataFrame

In [5]:
df = spark.createDataFrame(dataset).repartition(8)
df.schema

StructType([StructField('text', StringType(), True)])

In [6]:
df.take(1)

24/10/28 21:01:00 WARN TaskSetManager: Stage 0 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

[Row(text="Isaac Florentine has made some of the best western Martial Arts action movies ever produced. In particular US Seals 2, Cold Harvest, Special Forces and Undisputed 2 are all action classics. You can tell Isaac has a real passion for the genre and his films are always eventful, creative and sharp affairs, with some of the best fight sequences an action fan could hope for. In particular he has found a muse with Scott Adkins, as talented an actor and action performer as you could hope for. This is borne out with Special Forces and Undisputed 2, but unfortunately The Shepherd just doesn't live up to their abilities.<br /><br />There is no doubt that JCVD looks better here fight-wise than he has done in years, especially in the fight he has (for pretty much no reason) in a prison cell, and in the final showdown with Scott, but look in his eyes. JCVD seems to be dead inside. There's nothing in his eyes at all. It's like he just doesn't care about anything throughout the whole film.

In [7]:
df.write.mode("overwrite").parquet("imdb_test")

24/10/28 21:01:01 WARN TaskSetManager: Stage 2 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

#### Load and preprocess DataFrame

In [8]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [9]:
df = spark.read.parquet("imdb_test").limit(512)
df.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                    text|
+------------------------------------------------------------------------------------------------------------------------+
|First off let me say, If you haven't enjoyed a Van Damme movie since bloodsport, you probably will not like this movi...|
|I first watched this movie back in the mid/late 80's, when I was a kid. We couldn't even get all the way through it. ...|
|Low budget horror movie. If you don't raise your expectations too high, you'll probably enjoy this little flick. Begi...|
|Four things intrigued me as to this film - firstly, it stars Carly Pope (of "Popular" fame), who is always a pleasure...|
|Beware, My Lovely (1952) Dir: Harry Horner <br /><br />Production: The Filmmakers/RKO Radio Pictures<br /><br />Credu...|
|Now I understan

In [10]:
input_df = df.select(preprocess(col("text"), "translate English to French: ").alias("input")).cache()
input_df.show(truncate=128)



+--------------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                           input|
+--------------------------------------------------------------------------------------------------------------------------------+
|translate English to French: First off let me say, If you haven't enjoyed a Van Damme movie since bloodsport, you probably wi...|
|                             translate English to French: I first watched this movie back in the mid/late 80's, when I was a kid|
|                                                                            translate English to French: Low budget horror movie|
|translate English to French: Four things intrigued me as to this film - firstly, it stars Carly Pope (of "Popular" fame), who...|
|translate English to French: Beware, My Lovely (1952) Dir: Harry Horner <br /><br 

                                                                                

## Inference using Spark DL API

In [11]:
def predict_batch_fn():
    import numpy as np
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    
    print(f"Initializing model.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device.")
    model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        flattened = np.squeeze(inputs).tolist()
        inputs = tokenizer(flattened, 
                           padding=True,
                           return_tensors="pt").to(device)
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_length=128)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
        print("predict: {}".format(len(flattened)))
        
        return string_outputs
    
    return predict

In [12]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=128)

In [13]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 10:>                                                         (0 + 1) / 1]

CPU times: user 5.14 ms, sys: 2.09 ms, total: 7.22 ms
Wall time: 5.86 s


                                                                                

In [14]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 12:>                                                         (0 + 1) / 1]

CPU times: user 3.5 ms, sys: 1.34 ms, total: 4.84 ms
Wall time: 3.79 s


                                                                                

In [15]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 14:>                                                         (0 + 1) / 1]

CPU times: user 2.87 ms, sys: 1.92 ms, total: 4.79 ms
Wall time: 3.94 s


                                                                                

In [16]:
preds.show(truncate=128)

+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                           input|                                                                                                                           preds|
+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
|translate English to French: First off let me say, If you haven't enjoyed a Van Damme movie since bloodsport, you probably wi...|Je voudrais tout d'abord dire que si vous n'avez pas eu d'intérêt pour un film Van Damme d

## Using Triton Inference Server

In [17]:
from functools import partial

In [18]:
num_executors = 1
url = "localhost"
model_name = "ConditionalGeneration"

In [19]:
def triton_server():
    import signal
    import numpy as np
    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton

    with Triton() as triton:
        print("Loading Conditional Generation model.")
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        model = T5ForConditionalGeneration.from_pretrained("t5-small")
        
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {DEVICE} device.")
        model = model.to(DEVICE)

        @batch
        def _infer_fn(**inputs):
            sentences = np.squeeze(inputs["text"]).tolist()
            print(f"Received batch of size {len(sentences)}")
            decoded_sentences = [s.decode("utf-8") for s in sentences]
            inputs = tokenizer(decoded_sentences,
                            padding=True,
                            return_tensors="pt").to(DEVICE)
            output_ids = model.generate(input_ids=inputs["input_ids"],
                                        attention_mask=inputs["attention_mask"],
                                        max_length=128)
            outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])
            return {
                "translations": outputs,
            }

        triton.bind(
            model_name="ConditionalGeneration",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="text", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="translations", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=256,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        # Handle SIGTERM for graceful process shutdown
        def stop_triton(signum, frame):
            print("Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, stop_triton)

        print("Serving inference")
        triton.serve()

def start_triton(url, model_name):
    from multiprocessing import Process
    from pytriton.client import ModelClient

    process = Process(target=triton_server)
    process.start()

    client = ModelClient(url, model_name)
    ready = False
    while not ready:
        try:
            client.wait_for_server(5)
            ready = True
        except Exception as e:
            print(f"Waiting for server to be ready: {e}")
    
    return [process.pid]

sc = spark.sparkContext
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)
pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(url, model_name)).collect()
print(pids)

[Stage 17:>                                                         (0 + 1) / 1]

[3908791]


                                                                                

In [20]:
def triton_fn(url, model_name, init_timeout_s):
    import numpy as np
    from pytriton.client import ModelClient

    print(f"Connecting to model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, init_timeout_s=init_timeout_s) as client:
            flattened = np.squeeze(inputs).tolist() 
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["translations"], -1)
            return result_data
        
    return infer_batch

In [21]:
def preprocess(text: pd.Series, prefix: str = "") -> pd.Series:
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        return pd.Series([prefix + s.split(".")[0] for s in text])
    return _preprocess(text)

In [22]:
df = spark.read.parquet("imdb_test").limit(512)

In [23]:
input_df = df.select(preprocess(col("text"), "Translate English to French: ").alias("input")).cache()

In [24]:
generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name, init_timeout_s=600),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=128)

In [25]:
%%time
# first pass caches model/fn
preds = input_df.withColumn("preds", generate(struct("input")))
results = preds.collect()

[Stage 20:>                                                         (0 + 1) / 1]

CPU times: user 7.02 ms, sys: 3.49 ms, total: 10.5 ms
Wall time: 6.54 s


                                                                                

In [26]:
%%time
preds = input_df.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 22:>                                                         (0 + 1) / 1]

CPU times: user 3.53 ms, sys: 2.26 ms, total: 5.79 ms
Wall time: 3.79 s


                                                                                

In [27]:
%%time
preds = input_df.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 24:>                                                         (0 + 1) / 1]

CPU times: user 2.94 ms, sys: 1.47 ms, total: 4.41 ms
Wall time: 3.85 s


                                                                                

In [28]:
preds.show(truncate=128)

+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                           input|                                                                                                                           preds|
+--------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
|Translate English to French: First off let me say, If you haven't enjoyed a Van Damme movie since bloodsport, you probably wi...|Permettez-moi tout d'abord de dire que si vous n'avez pas joué un film Van Damme depuis le

In [29]:
def stop_triton(index, pids):
    import os
    import signal
    import time 

    pid = pids[index]
    os.kill(pid, signal.SIGTERM)

    for _ in range(5):
        try:
            os.kill(pid, 0)
        except ProcessLookupError:
            return [True]
        time.sleep(5)

    return [False]

nodeRDD.barrier().mapPartitionsWithIndex(lambda index, _: stop_triton(index, pids)).collect()

                                                                                

[True]