<img src="https://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark LLM Inference: Qwen-2.5-14b Data Structuring

In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct), using the open weights hosted on Hugging Face.

Qwen-2.5-14b-instruct is an instruction-tuned version of the Qwen-2.5-14b base model. We'll show how to use the model to convert unstructured text into a structured schema for downstream tasks.

**Note:** This example demonstrates **tensor parallelism**, which requires multiple GPUs per node. For standalone users, make sure to use a Spark worker with 2 GPUs. If you followed the Databricks or Dataproc instructions, the cluster configuration scripts will automatically acquire 2 GPUs per node.
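
To make the term concrete, the sketch below splits a weight matrix column-wise into two shards, computes each shard's partial output separately (as two GPUs would), and concatenates the results. It is a conceptual illustration in plain Python only: the helper names `matmul`, `split_columns`, and `tensor_parallel_matmul` are hypothetical, and vLLM's actual implementation is considerably more involved.

```python
def matmul(x, w):
    """Multiply matrix x (rows x k) by matrix w (k x cols), both as nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def split_columns(w, parts):
    """Split w column-wise into `parts` equal shards (one per simulated device)."""
    n_cols = len(w[0])
    if n_cols % parts != 0:
        raise ValueError("number of columns must be divisible by the shard count")
    step = n_cols // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]


def tensor_parallel_matmul(x, w, parts=2):
    """Compute x @ w by letting each shard produce a slice of the output columns."""
    partials = [matmul(x, shard) for shard in split_columns(w, parts)]
    # Concatenate the per-shard column slices back into full output rows.
    return [sum((p[r] for p in partials), []) for r in range(len(x))]


x = [[1.0, 2.0]]
w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
full = matmul(x, w)
sharded = tensor_parallel_matmul(x, w, parts=2)
print(full == sharded)
```

Each shard only needs its own slice of the weights, which is why a model too large for one GPU can be served across two.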

Check the cluster environment to handle any platform-specific configurations.

In [1]:
import os

on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

In [2]:
# For cloud environments, load the model to the distributed file system.
if on_databricks:
    models_dir = "/dbfs/FileStore/spark-dl-models"
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    model_path = f"{models_dir}/qwen2.5-14b"
elif on_dataproc:
    models_dir = "/mnt/gcs/spark-dl-models"
    os.makedirs(models_dir, exist_ok=True)  # create the directory only if it is missing
    model_path = f"{models_dir}/qwen2.5-14b"
else:
    model_path = os.path.abspath("qwen2.5-14b")

Download the model from the Hugging Face Hub.

In [3]:
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="Qwen/Qwen2.5-14B-Instruct",
    local_dir=model_path
)


## PySpark

In [4]:
import pandas as pd
from pyspark.sql.types import *
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat
from pyspark.ml.functions import predict_batch_udf

In [5]:
import os
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [6]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
        
    conf.set("spark.executor.cores", "24")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.task.resource.gpu.amount", "0.083333")
    conf.set("spark.executor.resource.gpu.amount", "2")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/04 23:26:12 INFO SparkEnv: Registering MapOutputTracker
25/03/04 23:26:12 INFO SparkEnv: Registering BlockManagerMaster
25/03/04 23:26:12 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
25/03/04 23:26:12 INFO SparkEnv: Registering OutputCommitCoordinator
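
The value `0.083333` for `spark.task.resource.gpu.amount` in the configuration above appears to be 2 GPUs divided across 24 executor cores, rounded down. The sketch below reproduces that arithmetic. It assumes Spark derives the number of task slots per GPU as `floor(1 / amount)`, which is why the fraction is rounded down rather than to nearest; the helper name `gpu_fraction_per_task` is hypothetical.

```python
import math


def gpu_fraction_per_task(gpus_per_executor: int, cores_per_executor: int, digits: int = 6) -> float:
    """Return the per-task GPU fraction that lets every core run one task."""
    tasks_per_gpu = cores_per_executor / gpus_per_executor
    exact = 1 / tasks_per_gpu
    scale = 10 ** digits
    # Round down: a fraction slightly above 1/12 would leave room for only 11 tasks per GPU.
    return math.floor(exact * scale) / scale


fraction = gpu_fraction_per_task(gpus_per_executor=2, cores_per_executor=24)
slots_per_gpu = math.floor(1 / fraction)  # assumed slot calculation, see lead-in
print(fraction, slots_per_gpu)
```

With 12 slots on each of the 2 GPUs, all 24 cores of an executor can run tasks concurrently.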


#### Load and Preprocess DataFrame

Load the first 500 samples of the [Amazon Video Game Product Reviews dataset](https://huggingface.co/datasets/logankells/amazon_product_reviews_video_games) from Hugging Face and store them in a Spark DataFrame.

In [7]:
product_reviews_ds = load_dataset("LoganKells/amazon_product_reviews_video_games", split="train", streaming=True)
product_reviews_pds = pd.Series([sample["reviewText"] for sample in product_reviews_ds.take(500)])

Repo card metadata block was not found. Setting CardData to empty.


In [8]:
df = spark.createDataFrame(product_reviews_pds, schema=StringType())

In [9]:
df.show(5, truncate=100)

                                                                                

+----------------------------------------------------------------------------------------------------+
|                                                                                               value|
+----------------------------------------------------------------------------------------------------+
|Installing the game was a struggle (because of games for windows live bugs).Some championship rac...|
|If you like rally cars get this game you will have fun.It is more oriented to &#34;European marke...|
|1st shipment received a book instead of the game.2nd shipment got a FAKE one. Game arrived with a...|
|I had Dirt 2 on Xbox 360 and it was an okay game. I started playing games on my laptop and bought...|
|Overall this is a well done racing game, with very good graphics for its time period. My family h...|
+----------------------------------------------------------------------------------------------------+
only showing top 5 rows



Format each sample into the Qwen chat template, including a system prompt to guide generation.

In [10]:
system_prompt = """You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.
IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.
For each review, analyze and output EXACTLY this JSON structure:
{
  "primary_sentiment": [EXACTLY ONE OF: "positive", "negative", "neutral", "mixed"],
  "sentiment_score": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],
  "purchase_intention": [EXACTLY ONE OF: "will repurchase", "might repurchase", "will not repurchase", "recommends alternatives", "uncertain"]
}

Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.
"""

df = df.select(
    concat(
        lit("<|im_start|>system\n"),
        lit(system_prompt),
        lit("<|im_end|>\n<|im_start|>user\n"),
        lit("Analyze this review: "),
        col("value"),
        lit("<|im_end|>\n<|im_start|>assistant\n")
    ).alias("prompt")
)
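
The same template can be written as a plain Python function, which is convenient for checking the prompt layout outside Spark. This is a minimal sketch that mirrors the `concat` expression above; the function name `build_chatml_prompt` is hypothetical and not part of the notebook.

```python
def build_chatml_prompt(system_prompt: str, review: str) -> str:
    """Assemble one prompt using the same markers as the Spark concat() above."""
    return (
        "<|im_start|>system\n"
        f"{system_prompt}"
        "<|im_end|>\n<|im_start|>user\n"
        f"Analyze this review: {review}"
        "<|im_end|>\n<|im_start|>assistant\n"
    )


example = build_chatml_prompt("Return JSON only.\n", "Great game, would buy again.")
print(example)
```

The prompt ends with the opening assistant marker so that generation starts directly with the model's reply.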

In [11]:
print(df.take(1)[0].prompt)


<|im_start|>system
You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.
IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.
For each review, analyze and output EXACTLY this JSON structure:
{
  "primary_sentiment": [EXACTLY ONE OF: "positive", "negative", "neutral", "mixed"],
  "sentiment_score": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],
  "purchase_intention": [EXACTLY ONE OF: "will repurchase", "might repurchase", "will not repurchase", "recommends alternatives", "uncertain"]
}

Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.
<|im_end|>
<|im_start|>user
Analyze this review: Installing the game was a struggle (because of games for windows live bugs).Some championship races and cars can only be "unlocked" by buying them as an

                                                                                

In [12]:
data_path = "spark-dl-datasets/amazon_video_game_reviews"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

                                                                                

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for deep learning.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-pytriton.png" alt="drawing" width="700"/>

In [13]:
from functools import partial

Import the helper class from pytriton_utils.py:

In [14]:
sc.addPyFile("pytriton_utils.py")

from pytriton_utils import TritonServerManager

Define the Triton Server function:

In [15]:
def triton_server(ports, model_path):
    import time
    import gc
    import signal
    import torch
    import numpy as np
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext
    from vllm import LLM, SamplingParams
    from vllm.distributed.parallel_state import destroy_model_parallel

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=50)
    model = LLM(model=model_path, gpu_memory_utilization=0.8, tensor_parallel_size=2, max_model_len=6600)

    @batch
    def _infer_fn(**inputs):
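        # Flatten (batch, 1) -> (batch,); np.squeeze alone would collapse a single-prompt batch to a scalar.
        prompts = inputs["prompts"].reshape(-1).tolist()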
        print(f"SERVER: Received batch of size {len(prompts)}")
        decoded_prompts = [p.decode("utf-8") for p in prompts]
        outputs = model.generate(decoded_prompts, sampling_params)
        return {
            "outputs": np.array([o.outputs[0].text for o in outputs]).reshape(-1, 1)
        }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="qwen-2.5",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="prompts", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="outputs", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            # Cleanup vLLM distributed workers
            destroy_model_parallel()
            nonlocal model
            del model
            gc.collect()
            torch.cuda.empty_cache()
            torch.distributed.destroy_process_group()
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP/gRPC/metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
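
As a rough illustration of the port-discovery step, the sketch below probes for three free TCP ports by attempting to bind to each candidate. It uses only the standard library; the function name `find_free_ports` and the starting port of 7000 are assumptions for illustration, and the helper in `pytriton_utils.py` may work differently.

```python
import socket


def find_free_ports(count: int = 3, start: int = 7000, host: str = "127.0.0.1") -> list:
    """Return `count` ports at or above `start` that can currently be bound."""
    ports = []
    port = start
    while len(ports) < count and port < 65536:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind((host, port))
            except OSError:
                pass  # port is in use or not permitted; try the next one
            else:
                ports.append(port)
        port += 1
    if len(ports) < count:
        raise RuntimeError(f"could not find {count} free ports starting at {start}")
    return ports


http_port, grpc_port, metrics_port = find_free_ports()
print(http_port, grpc_port, metrics_port)
```

A port found this way can still be taken by another process before the server binds it, so a real launcher should be prepared to retry.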

In [16]:
model_name = "qwen-2.5"
server_manager = TritonServerManager(model_name=model_name, model_path=model_path)

In [17]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server, wait_retries=48)  # allow up to 4 minutes for model loading

2025-03-04 23:26:35,065 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)
2025-03-04 23:26:35,068 - INFO - Starting 2 servers.
                                                                                

{'spark-dl-inference-vllm-w-1': (47708, [7000, 7001, 7002]),
 'spark-dl-inference-vllm-w-0': (61199, [7000, 7001, 7002])}

#### Define client function

Get the hostname -> url mapping from the server manager:

In [18]:
host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url
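
The mapping is a plain dictionary keyed by hostname. As a rough illustration, it could be derived from the `start_servers` return value shown earlier as follows. The `grpc://host:port` URL format and the helper name `build_grpc_urls` are assumptions; the actual property in `pytriton_utils.py` may differ.

```python
def build_grpc_urls(servers: dict) -> dict:
    """Map each hostname to a gRPC URL, given {host: (pid, [http, grpc, metrics])}."""
    return {
        host: f"grpc://{host}:{ports[1]}"  # index 1 is the gRPC port
        for host, (_pid, ports) in servers.items()
    }


# Values copied from the start_servers output above.
servers = {
    "spark-dl-inference-vllm-w-1": (47708, [7000, 7001, 7002]),
    "spark-dl-inference-vllm-w-0": (61199, [7000, 7001, 7002]),
}
example_urls = build_grpc_urls(servers)
print(example_urls)
```

Each Spark task then looks up its own hostname in this dictionary, so requests stay on the node where the task runs.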

In [19]:
def triton_fn(model_name, host_to_url):
    import json
    import socket
    import numpy as np
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=500) as client:
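            # Flatten (batch, 1) -> (batch,); np.squeeze alone would turn a one-row batch into a bare string.
            flattened = np.reshape(inputs, -1).tolist()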
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["outputs"], -1)
            # Load model json outputs into dictionaries
            result_dicts = [json.loads(o) for o in result_data]
            return result_dicts
        
    return infer_batch
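
Note that `json.loads` in the client above raises an exception if the model emits anything other than valid JSON, for example when a response is cut off by `max_tokens`. A more forgiving parser is sketched below: it extracts the outermost `{...}` span and falls back to null fields on failure. The helper name `parse_model_output` is hypothetical, and returning `None` for every field assumes the struct fields of the UDF return type are nullable.

```python
import json

EXPECTED_KEYS = ("primary_sentiment", "sentiment_score", "purchase_intention")


def parse_model_output(raw) -> dict:
    """Parse one model response into a dict with the expected keys, or all-None on failure."""
    text = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else str(raw)
    empty = {key: None for key in EXPECTED_KEYS}

    # Keep only the outermost {...} span, ignoring any text around it.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return empty
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return empty
    if not isinstance(parsed, dict):
        return empty

    result = {key: parsed.get(key) for key in EXPECTED_KEYS}
    score = result["sentiment_score"]
    # Drop scores that are not plain integers so they fit an integer column.
    if isinstance(score, bool) or not isinstance(score, int):
        result["sentiment_score"] = None
    return result


print(parse_model_output(b'Sure! {"primary_sentiment": "mixed", "sentiment_score": 5, "purchase_intention": "uncertain"}'))
```

Replacing `json.loads(o)` with such a helper would trade a failed Spark task for a row of nulls, which can then be filtered or retried downstream.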

In [20]:
generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),
                             return_type=StructType([
                                 StructField("primary_sentiment", StringType()),
                                 StructField("sentiment_score", IntegerType()),
                                 StructField("purchase_intention", StringType())
                             ]),
                             input_tensor_shapes=[[1]],
                             batch_size=64)

#### Load DataFrame

In [21]:
df = spark.read.parquet(data_path).repartition(16)

#### Run Inference

In [22]:
%%time
# first pass caches model/fn
preds = df.withColumn("outputs", generate(col("prompt"))).select("prompt", "outputs.*")
results = preds.collect()



CPU times: user 38.8 ms, sys: 8.33 ms, total: 47.2 ms
Wall time: 1min 2s


                                                                                

In [23]:
%%time
preds = df.withColumn("outputs", generate(col("prompt"))).select("prompt", "outputs.*")
results = preds.collect()



CPU times: user 23.8 ms, sys: 13.1 ms, total: 36.9 ms
Wall time: 59.3 s


                                                                                

In [30]:
preds.show(5, truncate=50)


+--------------------------------------------------+-----------------+---------------+-------------------+
|                                            prompt|primary_sentiment|sentiment_score| purchase_intention|
+--------------------------------------------------+-----------------+---------------+-------------------+
|<|im_start|>system\nYou are a specialized revie...|         positive|              9|    will repurchase|
|<|im_start|>system\nYou are a specialized revie...|         positive|              9|    will repurchase|
|<|im_start|>system\nYou are a specialized revie...|         positive|              8|    will repurchase|
|<|im_start|>system\nYou are a specialized revie...|         negative|              4|will not repurchase|
|<|im_start|>system\nYou are a specialized revie...|            mixed|              5|   might repurchase|
+--------------------------------------------------+-----------------+---------------+-------------------+
only showing top 5 rows



                                                                                

In [29]:
sample = results[0]
print("Review:", sample["prompt"])
print(f"Sentiment: {sample['primary_sentiment']}, Score: {sample['sentiment_score']}, Purchase intention: {sample['purchase_intention']}")

Review: <|im_start|>system
You are a specialized review analysis AI that categorizes product reviews into precise sentiment categories.
IMPORTANT: Your response must contain ONLY valid JSON and nothing else - no explanations, no additional text.
For each review, analyze and output EXACTLY this JSON structure:
{
  "primary_sentiment": [EXACTLY ONE OF: "positive", "negative", "neutral", "mixed"],
  "sentiment_score": [integer between 1-10, where 1 is extremely negative and 10 is extremely positive],
  "purchase_intention": [EXACTLY ONE OF: "will repurchase", "might repurchase", "will not repurchase", "recommends alternatives", "uncertain"]
}

Do not include any text before or after the JSON. The response should start with '{' and end with '}' with no trailing characters, comments, or explanations.
<|im_end|>
<|im_start|>user
Analyze this review: I have never played anything like this since. Everything from Sly  Racoon, to Ratchet and Clank, owe it to this.Wicked witch Gruntilda takes Ban

#### Shut down server on each executor

In [31]:
server_manager.stop_servers()

2025-03-04 23:34:40,993 - INFO - Requesting stage-level resources: (cores=13, gpu=2.0)
2025-03-04 23:34:49,882 - INFO - Sucessfully stopped 2 servers.                 


[True, True]

In [32]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()