<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark LLM Inference: Qwen-2.5 Text Summarization

In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights on Hugging Face.

Qwen-2.5-7B-Instruct is an instruction-tuned version of the Qwen-2.5-7B base model. We'll show how to use the model to perform text summarization.

**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.
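As a rough check of that figure: at 16-bit precision each parameter takes 2 bytes, so the weights alone need about 2 bytes times the parameter count. The sketch below assumes roughly 7.61B parameters, the figure listed on the Qwen2.5-7B-Instruct model card, and `weight_memory_gb` is a hypothetical helper. The KV cache and activations need memory on top of this, which is one reason vLLM is later configured to claim most of the GPU via `gpu_memory_utilization`.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate memory for the model weights alone, in decimal GB (1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Assumption: ~7.61B parameters (Qwen2.5-7B-Instruct model card); 2 bytes each at fp16/bf16.
fp16_gb = weight_memory_gb(7.61e9)
print(f"{fp16_gb:.1f} GB")  # 15.2 GB
```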

In [1]:
# Explicitly enable Hugging Face tokenizer parallelism; otherwise the tokenizers library may disable it once PySpark forks worker processes.
# See https://github.com/huggingface/transformers/issues/5486 for more info.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# vLLM initializes CUDA at import time. If vLLM has already been imported, a forked worker would try to re-initialize CUDA and raise an error, so use 'spawn' instead.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

Check the cluster environment to handle any platform-specific configurations.

In [2]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

In [3]:
# For cloud environments, save the model to the distributed file system.
if on_databricks:
    models_dir = "/dbfs/FileStore/spark-dl-models"
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    model_path = f"{models_dir}/qwen-2.5-7b"
elif on_dataproc:
    models_dir = "/mnt/gcs/spark-dl-models"
    os.makedirs(models_dir, exist_ok=True)
    model_path = f"{models_dir}/qwen-2.5-7b"
else:
    model_path = os.path.abspath("qwen-2.5-7b")
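The platform checks and path choices in the two cells above boil down to a pure function of the environment variables. A hypothetical `resolve_model_path` helper, shown only to make that logic testable in isolation (it does not create any directories), might look like this:

```python
import os
from typing import Mapping


def resolve_model_path(environ: Mapping[str, str], name: str = "qwen-2.5-7b") -> str:
    """Hypothetical helper mirroring the platform checks and model paths above."""
    if environ.get("DATABRICKS_RUNTIME_VERSION"):
        # Databricks: DBFS is exposed to local file APIs under /dbfs.
        return f"/dbfs/FileStore/spark-dl-models/{name}"
    if environ.get("DATAPROC_IMAGE_VERSION"):
        # Dataproc: assumes a GCS bucket mounted at /mnt/gcs, as in the cell above.
        return f"/mnt/gcs/spark-dl-models/{name}"
    # Standalone: keep the model next to the notebook.
    return os.path.abspath(name)


print(resolve_model_path(os.environ))
```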

Download the model from the Hugging Face Hub.

In [4]:
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="Qwen/Qwen2.5-7B-Instruct",
    local_dir=model_path
)


vLLM's `torch.compile` step can currently hang on Databricks; we are working on a fix. Until then, we enforce eager mode on Databricks, which disables compilation at some cost to performance.

In [5]:
enforce_eager = bool(on_databricks)

## Warmup: Running locally

**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.

In [None]:
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=128)
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
llm = LLM(model=model_path, gpu_memory_utilization=0.95, max_model_len=6600, enforce_eager=enforce_eager)
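`SamplingParams` above sets `temperature=0.7` and `top_p=0.8`. The sketch below illustrates what these two knobs do for a single next-token distribution in plain Python: logits are divided by the temperature, converted to probabilities, and sampling is restricted to the smallest set of top tokens whose cumulative probability reaches `top_p` (nucleus sampling). It is a simplified illustration, not vLLM's batched GPU implementation, and it does not model `repetition_penalty`; `nucleus_indices` and `sample_token` are hypothetical names.

```python
import math
import random
from typing import List, Optional, Sequence


def nucleus_indices(probs: Sequence[float], top_p: float) -> List[int]:
    """Smallest set of most-likely token ids whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    return kept


def sample_token(logits: Sequence[float], temperature: float = 0.7, top_p: float = 0.8,
                 rng: Optional[random.Random] = None) -> int:
    """Simplified temperature + top-p sampling for one next-token distribution."""
    rng = rng or random.Random()
    # Temperature: dividing by T < 1 sharpens the distribution toward the top tokens.
    scaled = [x / temperature for x in logits]
    # Softmax, subtracting the max logit for numerical stability.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p: sample only from the nucleus; random.choices renormalizes the weights.
    kept = nucleus_indices(probs, top_p)
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


logits = [2.0, 1.0, 0.1, -1.0]  # toy 4-token vocabulary
print(sample_token(logits, rng=random.Random(0)))  # always token 0 or 1 for these settings
```

With these toy logits, token 0 gets about 76% of the probability mass after the temperature is applied, so `top_p=0.8` also admits token 1 and excludes the two unlikely tokens entirely.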

In [7]:
system_prompt = {
    "role": "system",
    "content": "You are a knowledgeable AI assistant that provides accurate answers to questions."
}

queries = [
    "What does CUDA stand for?",
    "In one sentence, what's the difference between a CPU and a GPU?",
    "What's the hottest planet in the solar system?"
]

prompts = [
    [
        system_prompt,
        {"role": "user", "content": query}
    ] for query in queries
]

text = tokenizer.apply_chat_template(
    prompts,
    tokenize=False,
    add_generation_prompt=True,
)

In [8]:
outputs = llm.generate(text, sampling_params=sampling_params)

Processed prompts: 100%|██████████| 3/3 [00:01<00:00,  1.76it/s, est. speed input: 63.83 toks/s, output: 100.14 toks/s]


In [9]:
for q, o in zip(queries, outputs):
    print(f"Q: {q}")
    print(f"A: {o.outputs[0].text}\n")

Q: What does CUDA stand for?
A: CUDA stands for Compute Unified Device Architecture. It is a parallel computing platform and application programming interface (API) model created by NVIDIA. CUDA allows software developers to use a CUDA-enabled graphics processing unit (GPU) for general purpose processing.

Q: In one sentence, what's the difference between a CPU and a GPU?
A: A CPU (Central Processing Unit) is designed for general-purpose processing and managing the overall operations of a computer, while a GPU (Graphics Processing Unit) is specialized for parallel processing tasks, particularly those related to rendering graphics and accelerating machine learning tasks.

Q: What's the hottest planet in the solar system?
A: The hottest planet in the solar system is Venus. Despite Mercury being closer to the Sun, Venus has a thick atmosphere that traps heat in a runaway version of the greenhouse effect, creating a much hotter surface temperature than Mercury. The average surface temperat

In [None]:
import gc
import torch

# Free the local vLLM engine and its GPU memory before starting Spark.
del llm
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()

## PySpark

In [11]:
import pandas as pd
from pyspark.sql.types import *
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat
from pyspark.ml.functions import predict_batch_udf

In [12]:
import os
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

[2025-03-24 11:37:46] INFO config.py:54: PyTorch version 2.6.0 available.


#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For cloud service provider (CSP) environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [13]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/03/24 11:37:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/03/24 11:37:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/24 11:37:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
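With these settings, several tasks on an executor share its single GPU. As a simplified sketch (ignoring stage-level scheduling and any other custom resources), the number of concurrent task slots per executor is bounded by whichever resource runs out first. `tasks_per_executor` is a hypothetical helper, and `spark.task.cpus` is assumed to be at its default of 1:

```python
import math


def tasks_per_executor(executor_cores: int, task_cpus: int,
                       executor_gpus: float, task_gpus: float) -> int:
    """Concurrent task slots per executor: the most restrictive resource wins."""
    cpu_slots = executor_cores // task_cpus
    gpu_slots = math.floor(executor_gpus / task_gpus)
    return min(cpu_slots, gpu_slots)


# spark.executor.cores=8, spark.task.cpus=1 (default),
# spark.executor.resource.gpu.amount=1, spark.task.resource.gpu.amount=0.125
print(tasks_per_executor(8, 1, 1, 0.125))  # 8
```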


#### Load and Preprocess DataFrame

Load the first 500 samples of the [ML ArXiv dataset](https://huggingface.co/datasets/CShorten/ML-ArXiv-Papers) from Hugging Face and store them in a Spark DataFrame.

In [14]:
ml_arxiv_dataset = load_dataset("CShorten/ML-ArXiv-Papers", split="train", streaming=True)
ml_arxiv_pds = pd.Series([sample["abstract"] for sample in ml_arxiv_dataset.take(500)])

In [15]:
df = spark.createDataFrame(ml_arxiv_pds, schema=StringType())

In [16]:
df.show(5, truncate=100)


+----------------------------------------------------------------------------------------------------+
|                                                                                               value|
+----------------------------------------------------------------------------------------------------+
|  The problem of statistical learning is to construct a predictor of a random\nvariable $Y$ as a ...|
|  In a sensor network, in practice, the communication among sensors is subject\nto:(1) errors or ...|
|  The on-line shortest path problem is considered under various models of\npartial monitoring. Gi...|
|  Ordinal regression is an important type of learning, which has properties of\nboth classificati...|
|  This paper uncovers and explores the close relationship between Monte Carlo\nOptimization of a ...|
+----------------------------------------------------------------------------------------------------+
only showing top 5 rows



Format each sample into the chat template, including a system prompt to guide generation.

In [17]:
system_prompt = '''You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary 
of a research abstract that captures the main objective, methodology, and key findings, using clear 
language while preserving technical accuracy and quantitative results.'''

df = df.select(
    concat(
        lit("<|im_start|>system\n"),
        lit(system_prompt),
        lit("<|im_end|>\n<|im_start|>user\n"),
        col("value"),
        lit("<|im_end|>\n<|im_start|>assistant\n")
    ).alias("prompt")
)

In [18]:
print(df.take(1)[0].prompt)

<|im_start|>system
You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary 
of a research abstract that captures the main objective, methodology, and key findings, using clear 
language while preserving technical accuracy and quantitative results.<|im_end|>
<|im_start|>user
  The problem of statistical learning is to construct a predictor of a random
variable $Y$ as a function of a related random variable $X$ on the basis of an
i.i.d. training sample from the joint distribution of $(X,Y)$. Allowable
predictors are drawn from some specified class, and the goal is to approach
asymptotically the performance (expected loss) of the best predictor in the
class. We consider the setting in which one has perfect observation of the
$X$-part of the sample, while the $Y$-part has to be communicated at some
finite bit rate. The encoding of the $Y$-values is allowed to depend on the
$X$-values. Under suitable regularity conditions on the admissible predictors,
the underlying

In [19]:
data_path = "spark-dl-datasets/arxiv_abstracts"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

## Using vLLM Server
In this section, we demonstrate integration with [vLLM Serving](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), an open-source server with an OpenAI-compatible completions endpoint for LLMs.  

The process looks like this:
- Distribute a server startup task across the Spark cluster, instructing each node to launch a vLLM server process.
- Define a vLLM inference function, which sends inference requests to the local server on a given node.
- Wrap the vLLM inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the vLLM server processes on each node.

<img src="../images/spark-server.png" alt="drawing" width="700"/>
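The lifecycle above can be sketched on a single machine with only the Python standard library. In this sketch, a hypothetical `FakeCompletionsHandler` stands in for vLLM (it just upper-cases each prompt and mimics only the `choices[].text` shape of a completions response), and there is no Spark: one process starts the server on a free port, sends a batched request to `/v1/completions`, and shuts the server down.

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer


class FakeCompletionsHandler(BaseHTTPRequestHandler):
    """Hypothetical stand-in for a vLLM server: echoes each prompt, upper-cased."""

    def do_POST(self):
        if self.path != "/v1/completions":
            self.send_error(404)
            return
        body = json.loads(self.rfile.read(int(self.headers["Content-Length"])))
        choices = [{"index": i, "text": p.upper()} for i, p in enumerate(body["prompt"])]
        payload = json.dumps({"choices": choices}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, *args):  # keep the demo quiet
        pass


def run_lifecycle(prompts):
    # 1. Start: port 0 lets the OS pick a free port.
    server = ThreadingHTTPServer(("127.0.0.1", 0), FakeCompletionsHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    url = f"http://127.0.0.1:{server.server_address[1]}"
    try:
        # 2-3. Client: send the whole batch of prompts in one request.
        req = urllib.request.Request(
            f"{url}/v1/completions",
            data=json.dumps({"model": "demo", "prompt": prompts}).encode(),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            return [c["text"] for c in json.loads(resp.read())["choices"]]
    finally:
        # 4. Shutdown: stop the serve loop and release the socket.
        server.shutdown()
        server.server_close()


print(run_lifecycle(["hello", "spark"]))  # ['HELLO', 'SPARK']
```

In the notebook, `VLLMServerManager` plays the role of steps 1 and 4 on every node, and the `predict_batch_udf` plays the role of the client.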

In [20]:
from functools import partial

Import the helper class from server_utils.py:

In [21]:
sc.addPyFile("server_utils.py")

from server_utils import VLLMServerManager

#### Start vLLM servers

The `VLLMServerManager` will handle the lifecycle of vLLM server instances across the Spark cluster:
- Find available ports for HTTP
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes

In [22]:
model_name = "qwen-2.5-7b"
server_manager = VLLMServerManager(model_name=model_name, model_path=model_path)


You can pass any of the supported [vLLM serve CLI arguments](https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html#vllm-serve) as keyword arguments when starting the servers. Note that startup can take some time, as it includes loading the model from disk, Torch compilation, and capturing CUDA graphs.
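As an illustration of the kwargs-to-flags convention, a hypothetical `to_cli_args` helper could map snake_case keyword arguments onto `vllm serve`-style kebab-case flags, emitting boolean switches such as `--enforce-eager` only when true. This is an assumption about the mapping, not a description of how `VLLMServerManager` is implemented; `wait_retries`, for instance, appears to be an option of the server manager itself rather than a vLLM flag, so it is left out here.

```python
from typing import Any, List


def to_cli_args(**kwargs: Any) -> List[str]:
    """Hypothetical conversion of Python kwargs into `vllm serve`-style flags."""
    args: List[str] = []
    for key, value in kwargs.items():
        flag = "--" + key.replace("_", "-")
        if isinstance(value, bool):
            # Boolean switches take no value and are omitted when False.
            if value:
                args.append(flag)
        else:
            args.extend([flag, str(value)])
    return args


print(to_cli_args(gpu_memory_utilization=0.95, max_model_len=6600, task="generate", enforce_eager=True))
# ['--gpu-memory-utilization', '0.95', '--max-model-len', '6600', '--task', 'generate', '--enforce-eager']
```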

In [23]:
server_manager.start_servers(gpu_memory_utilization=0.95,
                             max_model_len=6600,
                             task="generate",
                             enforce_eager=enforce_eager,
                             wait_retries=60)

[2025-03-24 11:37:57] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)
[2025-03-24 11:37:57] INFO server_utils.py:390: Starting 1 VLLM servers.



{'cb4ae00-lcedt': (4022579, [7000])}

#### Define client function

Get the hostname-to-URL mapping from the server manager:

In [24]:
host_to_http_url = server_manager.host_to_http_url

Define the vLLM inference function, which returns a predict function for batch inference through the server:

In [25]:
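def vllm_fn(model_name, host_to_url):
    import socket
    import numpy as np
    import requests

    # Each task talks to the vLLM server running on its own node.
    url = host_to_url[socket.gethostname()]

    def predict(inputs):
        # Send the whole batch of prompts in a single completions request.
        response = requests.post(
            f"{url}/v1/completions",
            json={
                "model": model_name,
                "prompt": inputs.tolist(),
                "max_tokens": 128,
                "temperature": 0.7,
                "top_p": 0.8,
                "repetition_penalty": 1.05,
            }
        )
        # Fail loudly on HTTP errors instead of with a confusing KeyError below.
        response.raise_for_status()
        return np.array([r["text"] for r in response.json()["choices"]])

    return predict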
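The client above assumes the returned `choices` line up one-to-one with the input prompts. A slightly more defensive parser, sketched here as a hypothetical `extract_texts` helper (not part of `server_utils.py`), checks the count and orders choices by the per-choice `index` field of OpenAI-compatible completion responses, so `predict_batch_udf` always receives exactly one output per input row. It can be tested without a running server:

```python
from typing import Any, Dict, List


def extract_texts(response_json: Dict[str, Any], num_prompts: int) -> List[str]:
    """Return completion texts aligned with the input prompts (hypothetical helper)."""
    choices = response_json["choices"]
    if len(choices) != num_prompts:
        raise ValueError(f"expected {num_prompts} choices, got {len(choices)}")
    # Order by each choice's `index` rather than trusting list position.
    return [c["text"] for c in sorted(choices, key=lambda c: c["index"])]


fake = {"choices": [{"index": 1, "text": "second"}, {"index": 0, "text": "first"}]}
print(extract_texts(fake, num_prompts=2))  # ['first', 'second']
```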

In [26]:
generate = predict_batch_udf(partial(vllm_fn, model_name=model_name, host_to_url=host_to_http_url),
                             return_type=StringType(),
                             batch_size=32)

#### Load DataFrame

We'll parallelize over a small set of prompts for demonstration.

In [27]:
df = spark.read.parquet(data_path).limit(256).repartition(8)
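`predict_batch_udf` splits each partition into batches of at most `batch_size` rows, and the client above sends each batch as one HTTP request. Assuming rows are spread evenly across partitions, a back-of-the-envelope count for this demo (hypothetical `num_requests` helper; with an uneven split it gives an upper bound) is:

```python
import math


def num_requests(num_rows: int, num_partitions: int, batch_size: int) -> int:
    """Requests sent to the servers, assuming rows are spread evenly across partitions."""
    rows_per_partition = math.ceil(num_rows / num_partitions)
    batches_per_partition = math.ceil(rows_per_partition / batch_size)
    return num_partitions * batches_per_partition


# 256 rows over 8 partitions -> 32 rows each -> one batch of 32 per partition.
print(num_requests(256, 8, 32))  # 8
```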

#### Run Inference

In [28]:
%%time
# the first pass is slower: it launches the Python workers and caches the predict function on each one
preds = df.withColumn("outputs", generate(col("prompt")))
results = preds.collect()



CPU times: user 7.53 ms, sys: 2.19 ms, total: 9.72 ms
Wall time: 13.9 s



In [29]:
%%time
preds = df.withColumn("outputs", generate(col("prompt")))
results = preds.collect()



CPU times: user 10.7 ms, sys: 3.65 ms, total: 14.3 ms
Wall time: 6.26 s



Sample output:

In [30]:
print(f"Q: {results[0].prompt} \n")
print(f"A: {results[0].outputs} \n")

Q: <|im_start|>system
You are a knowledgeable AI assistant. Your job is to create a 1 sentence summary 
of a research abstract that captures the main objective, methodology, and key findings, using clear 
language while preserving technical accuracy and quantitative results.<|im_end|>
<|im_start|>user
  Images can be segmented by first using a classifier to predict an affinity
graph that reflects the degree to which image pixels must be grouped together
and then partitioning the graph to yield a segmentation. Machine learning has
been applied to the affinity classifier to produce affinity graphs that are
good in the sense of minimizing edge misclassification rates. However, this
error measure is only indirectly related to the quality of segmentations
produced by ultimately partitioning the affinity graph. We present the first
machine learning algorithm for training a classifier to produce affinity graphs
that are good in the sense of producing segmentations that directly minimize
the R

#### Shut down server on each executor

In [31]:
server_manager.stop_servers()

[2025-03-24 11:38:49] INFO server_utils.py:359: Requesting stage-level resources: (cores=5, gpu=1.0)
[2025-03-24 11:38:50] INFO server_utils.py:447: Successfully stopped 1 VLLM servers.


[True]

In [32]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()