<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark LLM Inference: Qwen-2.5 Text Summarization

In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using open weights from Hugging Face.

Qwen-2.5-7B-Instruct is an instruction-tuned version of the Qwen-2.5-7B base model. We'll show how to use it to perform text summarization.

**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU memory. Make sure your instances have sufficient GPU capacity.
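As a rough check on that figure: in bfloat16, each parameter takes 2 bytes, so the weights alone need about 2 bytes times the parameter count. The sketch below assumes the ~7.61B parameter count listed on the Qwen2.5-7B-Instruct model card and ignores the KV cache and activations, which add to the total during generation.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory (GB, 1e9 bytes) needed just to hold the model weights."""
    return num_params * bytes_per_param / 1e9


# Assumption: ~7.61B parameters, as listed on the Qwen2.5-7B-Instruct model card.
QWEN_7B_PARAMS = 7.61e9

bf16_gb = weight_memory_gb(QWEN_7B_PARAMS, bytes_per_param=2)  # bfloat16 = 2 bytes
fp32_gb = weight_memory_gb(QWEN_7B_PARAMS, bytes_per_param=4)  # float32 = 4 bytes
print(f"bf16 weights: ~{bf16_gb:.1f} GB, fp32 weights: ~{fp32_gb:.1f} GB")
```

The 16-bit weights already take roughly 15 GB before any tokens are generated, which is why loading in full 32-bit precision would need about twice the GPU memory.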

The dataset we'll use is Zstandard-compressed, so we need the `zstandard` package to read it.

In [16]:
%pip install zstandard

Collecting zstandard
  Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: zstandard
Successfully installed zstandard-0.23.0
Note: you may need to restart the kernel to use updated packages.


In [1]:
# Explicitly enable Hugging Face tokenizer parallelism; otherwise the tokenizers library
# disables it once the process forks, which happens with PySpark workers.
# See https://github.com/huggingface/transformers/issues/5486 for more info.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

Check the cluster environment to handle any platform-specific configurations.

In [2]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

In [3]:
# For cloud environments, save the model to the distributed file system.
if on_databricks:
    models_dir = "/dbfs/FileStore/spark-dl-models"
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    model_path = f"{models_dir}/qwen-2.5-7b"
elif on_dataproc:
    models_dir = "/mnt/gcs/spark-dl-models"
    os.makedirs(models_dir, exist_ok=True)
    model_path = f"{models_dir}/qwen-2.5-7b"
else:
    model_path = os.path.abspath("qwen-2.5-7b")

Download the model from the Hugging Face Hub.

In [None]:
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="Qwen/Qwen2.5-7B-Instruct",
    local_dir=model_path
)

## Warmup: Running locally

**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.

In [16]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")


In [17]:
system_prompt = {
    "role": "system",
    "content": "You are a knowledgeable AI assistant that provides accurate answers to questions."
}

queries = [
    "How many vowels are in 'elephant'?",
    "What is the square root of 16?",
    "How many planets are in our solar system?"
]

prompts = [
    [
        system_prompt,
        {"role": "user", "content": query}
    ] for query in queries
]

text = tokenizer.apply_chat_template(
    prompts,
    tokenize=False,
    add_generation_prompt=True,
)

model_inputs = tokenizer(text, return_tensors="pt", padding=True).to(model.device)

In [23]:
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=256,
)

outputs = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens = True)
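Slicing `generated_ids[:, model_inputs.input_ids.shape[1]:]` relies on the tokenizer having been loaded with `padding_side="left"`. Padding then goes before each prompt, so every prompt's last real token sits in the final column and a decoder-only model continues directly from it; the new tokens therefore occupy the same columns in every row, and one slice at the padded prompt length extracts them. A toy illustration with made-up integer token IDs (the pad ID `0` and the "generated" tokens are hypothetical; no model is involved):

```python
PAD_ID = 0  # hypothetical pad token ID


def left_pad(batch, pad_id=PAD_ID):
    """Left-pad token sequences to a common length (mirrors padding_side="left")."""
    width = max(len(seq) for seq in batch)
    return [[pad_id] * (width - len(seq)) + seq for seq in batch]


prompts = [[11, 12, 13], [21]]
padded = left_pad(prompts)   # [[11, 12, 13], [0, 0, 21]]
prompt_len = len(padded[0])  # analogous to model_inputs.input_ids.shape[1]

# Pretend the model appended these new tokens after each row's last real token.
new_tokens = [[101, 102], [201, 202]]
generated = [row + extra for row, extra in zip(padded, new_tokens)]

# Every prompt ends at the same column, so one slice keeps only the completions.
completions = [row[prompt_len:] for row in generated]
print(completions)  # [[101, 102], [201, 202]]
```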

In [24]:
for query, output in zip(queries, outputs):
    print(f"Q: {query}\nA: {output}\n")

Q: How many vowels are in 'elephant'?
A: The word "elephant" contains 3 vowels. The vowels are 'e', 'e', and 'a'.

Q: What is the square root of 16?
A: The square root of 16 is 4, because \(4 \times 4 = 16\).

Q: How many planets are in our solar system?
A: There are eight planets in our solar system. They are, in order from the Sun:

1. Mercury
2. Venus
3. Earth
4. Mars
5. Jupiter
6. Saturn
7. Uranus
8. Neptune

Pluto was previously considered the ninth planet but is now classified as a dwarf planet.



In [25]:
import torch
del model
torch.cuda.empty_cache()

## PySpark

In [4]:
import pandas as pd
from pyspark.sql.types import *
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat
from pyspark.ml.functions import predict_batch_udf

In [5]:
import os
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For cloud service provider (CSP) environments, Spark is either preconfigured (Databricks) or we'll need to create the Spark Session ourselves (Dataproc).

In [6]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/02/16 11:48:57 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/02/16 11:48:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/16 11:48:57 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
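With `spark.executor.cores=8` and `spark.task.resource.gpu.amount=0.125`, up to 8 tasks can run concurrently on each executor, all sharing its single GPU. A simplified model of that slot calculation, assuming the default `spark.task.cpus=1` and that Spark floors fractional per-task GPU amounts when computing slots (`concurrent_tasks` is a hypothetical helper, not a Spark API):

```python
import math


def concurrent_tasks(executor_cores: int, task_cpus: int,
                     executor_gpus: float, task_gpu_amount: float) -> int:
    """Simplified slot count: the scarcer of CPU slots and GPU slots wins."""
    cpu_slots = executor_cores // task_cpus
    # Each GPU supports floor(1 / task_gpu_amount) concurrent tasks.
    gpu_slots = int(executor_gpus) * math.floor(1 / task_gpu_amount)
    return min(cpu_slots, gpu_slots)


print(concurrent_tasks(8, 1, 1, 0.125))  # 8 tasks share one GPU
print(concurrent_tasks(8, 1, 1, 0.25))   # 4: the GPU becomes the limit
```

Matching `1 / task_gpu_amount` to the executor core count keeps both resources fully used; raising one without the other just leaves the extra capacity idle.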


#### Load and Preprocess DataFrame

Load the first 500 samples of the [PubMed abstracts dataset](https://huggingface.co/datasets/casinca/PUBMED_title_abstracts_2019_baseline) from Hugging Face and store them in a Spark DataFrame.

In [7]:
pubmed_dataset = load_dataset("casinca/PUBMED_title_abstracts_2019_baseline", split="train", streaming=True)
pubmed_pds = pd.Series([sample["text"] for sample in pubmed_dataset.take(500)])

In [8]:
df = spark.createDataFrame(pubmed_pds, schema=StringType())

In [9]:
df.show(5, truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                               value|
+----------------------------------------------------------------------------------------------------+
|Epidemiology of hypoxaemia in children with acute lower respiratory infection.\nTo determine the ...|
|Clinical signs of hypoxaemia in children with acute lower respiratory infection: indicators of ox...|
|Hypoxaemia in children with severe pneumonia in Papua New Guinea.\nTo investigate the severity an...|
|Oxygen concentrators and cylinders.\nA comparison is made between oxygen cylinders and oxygen con...|
|Oxygen supply in rural africa: a personal experience.\nOxygen is one of the essential medical sup...|
+----------------------------------------------------------------------------------------------------+
only showing top 5 rows




Format each sample into the chat template, including a system prompt to guide generation.

In [11]:
system_prompt = '''You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary 
of a research abstract that captures the main objective, methodology, and key findings, using clear 
language while preserving technical accuracy and quantitative results.'''

df = df.select(
    concat(
        lit("<|im_start|>system\n"),
        lit(system_prompt),
        lit("<|im_end|>\n<|im_start|>user\n"),
        col("value"),
        lit("<|im_end|>\n<|im_start|>assistant\n")
    ).alias("prompt")
)
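The `concat` expression builds each prompt in the ChatML format that Qwen-2.5 expects, without calling the tokenizer on every row. The same formatting in plain Python, handy for spot-checking a single prompt (the helper name `to_chatml` is ours, not part of any library):

```python
def to_chatml(system_prompt: str, user_message: str) -> str:
    """Format one system + user turn in ChatML, ending with an open assistant turn."""
    return (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_message}<|im_end|>\n"
        "<|im_start|>assistant\n"  # the model continues generating from here
    )


prompt = to_chatml("You are a helpful assistant.", "Summarize this abstract: ...")
print(prompt)
```

To confirm the hand-written template matches the model's own, compare its output with `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` for the same messages.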

In [None]:
print(df.take(1)[0].prompt)

<|im_start|>system
You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary 
of a research abstract that captures the main objective, methodology, and key findings, using clear 
language while preserving technical accuracy and quantitative results.<|im_end|>
<|im_start|>user
Epidemiology of hypoxaemia in children with acute lower respiratory infection.
To determine the prevalence of hypoxaemia in children aged under 5 years suffering acute lower respiratory infections (ALRI), the risk factors for hypoxaemia in children under 5 years of age with ALRI, and the association of hypoxaemia with an increased risk of dying in children of the same age. Systematic review of the published literature. Out-patient clinics, emergency departments and hospitalisation wards in 23 health centres from 10 countries. Cohort studies reporting the frequency of hypoxaemia in children under 5 years of age with ALRI, and the association between hypoxaemia and the risk of dying. Prevale

In [13]:
data_path = "spark-dl-datasets/pubmed_abstracts"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").parquet(data_path)

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for deep learning.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-server.png" alt="drawing" width="700"/>

In [14]:
from functools import partial

Import the helper class from `server_utils.py`:

In [15]:
sc.addPyFile("server_utils.py")

from server_utils import TritonServerManager

Define the Triton Server function:

In [16]:
def triton_server(ports, model_path):
    import time
    import signal
    import torch
    import numpy as np
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")

    @batch
    def _infer_fn(**inputs):
        # Flatten (batch, 1) -> (batch,) so a single-request batch still yields a list.
        prompts = inputs["prompts"].reshape(-1).tolist()
        print(f"SERVER: Received batch of size {len(prompts)}")
        decoded_prompts = [p.decode("utf-8") for p in prompts]
        tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors="pt").to(model.device)
        generated_ids = model.generate(**tokenized_inputs, max_new_tokens=256)
        # Keep only the newly generated tokens (prompts are left-padded to a common length).
        outputs = tokenizer.batch_decode(generated_ids[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return {
            "outputs": np.array(outputs).reshape(-1, 1)
        }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="qwen-2.5",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="prompts", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="outputs", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=64,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()
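`DynamicBatcher(max_queue_delay_microseconds=5000)` lets Triton hold incoming requests for up to 5 ms so that it can merge them into a larger batch (up to `max_batch_size=64`) before calling `_infer_fn`. The following is only a conceptual sketch of that policy using the standard library, not Triton's implementation; `collect_batch` is a hypothetical helper.

```python
import queue
import time


def collect_batch(requests: queue.Queue, max_batch_size: int, max_delay_s: float) -> list:
    """Gather up to max_batch_size requests, waiting at most max_delay_s after the first."""
    batch = [requests.get()]  # block until at least one request arrives
    deadline = time.monotonic() + max_delay_s
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # delay budget spent: run with what we have
        try:
            batch.append(requests.get(timeout=remaining))
        except queue.Empty:
            break  # no more requests arrived in time
    return batch


q = queue.Queue()
for i in range(10):
    q.put(f"prompt-{i}")

sizes = []
while not q.empty():
    sizes.append(len(collect_batch(q, max_batch_size=4, max_delay_s=0.005)))
print(sizes)
```

The trade-off is latency for throughput: a longer queue delay produces fuller batches for the GPU, while each request may wait up to that delay before inference starts.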

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP, gRPC, and metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
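Port discovery can be sketched by trying to bind candidate ports on the host. This is a hypothetical helper, not the code in `server_utils.py`; it starts scanning at 7000 only to match the ports that appear in this notebook's server output. A port found this way can still be taken by another process before the server binds it, so a real implementation should retry on failure.

```python
import socket


def find_free_ports(count: int = 3, start: int = 7000, end: int = 7999) -> list[int]:
    """Return `count` ports in [start, end] that can currently be bound on this host."""
    ports = []
    for port in range(start, end + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("", port))
            except OSError:
                continue  # port in use; try the next one
        ports.append(port)
        if len(ports) == count:
            return ports
    raise RuntimeError(f"Could not find {count} free ports in [{start}, {end}]")


http_port, grpc_port, metrics_port = find_free_ports()
print(http_port, grpc_port, metrics_port)
```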

In [18]:
model_name = "qwen-2.5"
server_manager = TritonServerManager(model_name=model_name, model_path=model_path)

In [19]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading

2025-02-16 11:49:25,237 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-16 11:49:25,239 - INFO - Starting 1 servers.



{'cb4ae00-lcedt': (3490378, [7000, 7001, 7002])}

#### Define client function

Get the hostname-to-URL mapping from the server manager:

In [20]:
host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url
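Conceptually, the mapping pairs each server's hostname with its gRPC (or HTTP) port. A hypothetical reconstruction from the dictionary returned by `start_servers`; the `grpc://host:port` URL format is an assumption, and `server_manager.host_to_grpc_url` remains the authoritative source:

```python
def build_host_to_url(servers: dict, protocol: str = "grpc") -> dict:
    """Map hostname -> URL from {hostname: (pid, [http_port, grpc_port, metrics_port])}."""
    port_index = {"http": 0, "grpc": 1}[protocol]
    return {
        host: f"{protocol}://{host}:{ports[port_index]}"
        for host, (_pid, ports) in servers.items()
    }


servers = {"cb4ae00-lcedt": (3490378, [7000, 7001, 7002])}  # output of start_servers above
print(build_host_to_url(servers))          # {'cb4ae00-lcedt': 'grpc://cb4ae00-lcedt:7001'}
print(build_host_to_url(servers, "http"))  # {'cb4ae00-lcedt': 'http://cb4ae00-lcedt:7000'}
```

Each Spark task then looks up its own entry with `socket.gethostname()`, so requests always go to the server on the same node.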

In [21]:
def triton_fn(model_name, host_to_url):
    import socket
    import numpy as np
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=500) as client:
            # Flatten (batch, 1) -> (batch,); np.squeeze would return a bare string for a batch of one.
            flattened = np.reshape(inputs, -1).tolist()
            # Encode each prompt as a UTF-8 bytes row of shape (1,)
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["outputs"], -1)
            return result_data

    return infer_batch
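One subtlety when flattening the `(batch, 1)` input array: `np.squeeze` drops *every* size-1 axis, so a batch containing a single row becomes a 0-d array and `.tolist()` returns a bare string rather than a list. Iterating over that string would then encode individual characters. `reshape(-1)` always yields a 1-D array. A quick demonstration with a made-up prompt:

```python
import numpy as np

batch = np.array([["only prompt"]], dtype=object)  # shape (1, 1): a batch of one row

squeezed = np.squeeze(batch)   # drops both size-1 axes
flattened = batch.reshape(-1)  # always 1-D

print(squeezed.shape, type(squeezed.tolist()))    # () <class 'str'>
print(flattened.shape, type(flattened.tolist()))  # (1,) <class 'list'>
```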

In [22]:
generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=8)

#### Load DataFrame

We'll parallelize over a small set of prompts for demonstration.

In [23]:
df = spark.read.parquet(data_path).limit(64).repartition(8)
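With 64 rows spread over 8 partitions and `batch_size=8`, each Spark task sends a single request of 8 prompts, so the server receives 8 concurrent requests that its dynamic batcher may merge. The arithmetic, assuming the rows are split evenly across partitions (`requests_per_task` is a hypothetical helper):

```python
import math


def requests_per_task(total_rows: int, num_partitions: int, batch_size: int) -> int:
    """Number of inference requests each Spark task sends, assuming an even split."""
    rows_per_partition = math.ceil(total_rows / num_partitions)
    return math.ceil(rows_per_partition / batch_size)


TOTAL_ROWS, NUM_PARTITIONS, BATCH_SIZE = 64, 8, 8
per_task = requests_per_task(TOTAL_ROWS, NUM_PARTITIONS, BATCH_SIZE)
total_requests = per_task * NUM_PARTITIONS
print(f"{per_task} request(s) per task, {total_requests} requests in total")
```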

#### Run Inference

In [24]:
%%time
# first pass caches model/fn
preds = df.withColumn("outputs", generate(col("prompt")))
results = preds.collect()



CPU times: user 10.5 ms, sys: 6.63 ms, total: 17.1 ms
Wall time: 23.7 s



In [25]:
%%time
preds = df.withColumn("outputs", generate(col("prompt")))
results = preds.collect()



CPU times: user 8.1 ms, sys: 4.47 ms, total: 12.6 ms
Wall time: 21.7 s
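For a rough sense of throughput, divide the number of prompts by the wall time of the second (warm) run. This figure includes Spark scheduling and client overhead, and each prompt may generate up to 256 new tokens, so treat it as an end-to-end number for this small demo rather than a benchmark:

```python
NUM_PROMPTS = 64    # rows in the demo DataFrame
WALL_TIME_S = 21.7  # wall time of the second run above

throughput = NUM_PROMPTS / WALL_TIME_S
print(f"~{throughput:.2f} prompts/s, ~{WALL_TIME_S / NUM_PROMPTS:.2f} s per prompt")
```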



In [28]:
print(f"Q: {results[0].prompt} \n")
print(f"A: {results[0].outputs} \n")

Q: <|im_start|>system
You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary 
of a research abstract that captures the main objective, methodology, and key findings, using clear 
language while preserving technical accuracy and quantitative results.<|im_end|>
<|im_start|>user
Oral health promotion evaluation--time for development.
Increasing emphasis is now being placed upon the evaluation of health service interventions to demonstrate their effects. A series of effectiveness reviews of the oral health education and promotion literature has demonstrated that many of these interventions are poorly and inadequately evaluated. It is therefore difficult to determine the effectiveness of many interventions. Based upon developments from the field of health promotion research this paper explores options for improving the quality of oral health promotion evaluation. It is essential that the methods and measures used in the evaluation of oral health promotion are app

#### Shut down server on each executor

In [29]:
server_manager.stop_servers()

2025-02-16 11:51:42,365 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)


2025-02-16 11:51:47,609 - INFO - Sucessfully stopped 1 servers.                 


[True]

In [30]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()