<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark LLM Inference: DeepSeek-R1 Reasoning Q/A

In this notebook, we demonstrate distributed batch inference with [DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1), using the open weights published on Hugging Face.

We use [DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) for demonstration. DeepSeek's distilled models are based on open-source LLMs (such as Llama and Qwen) and are fine-tuned on samples generated by DeepSeek-R1. We'll show how to use the model to reason through word problems.

**Note:** Running this model on GPU with 16-bit precision requires **~18GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.
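
As a rough sanity check on that figure, the weight footprint can be estimated from the parameter count and the bytes per parameter. The sketch below assumes roughly 8 billion parameters (an approximation read off the "8B" in the model name), and the function name is illustrative. The gap between this estimate and ~18GB is presumably taken up by activations, the KV cache, and CUDA runtime overhead.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Estimate the memory needed for model weights alone, in GB (10^9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Assumption: ~8e9 parameters, inferred from the "8B" in the model name.
weights_gb = weight_memory_gb(8e9, bytes_per_param=2)  # bf16/fp16 use 2 bytes each
print(f"Weights alone: ~{weights_gb:.0f} GB")
```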

In [None]:
# Explicitly enable Hugging Face tokenizer parallelism; otherwise the tokenizers library
# disables it (with a warning) once PySpark forks worker processes.
# See https://github.com/huggingface/transformers/issues/5486 for more info.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

Check the cluster environment to handle any platform-specific configurations.

In [1]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

In [2]:
# For cloud environments, load the model to the distributed file system.
if on_databricks:
    models_dir = "/dbfs/FileStore/spark-dl-models"
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    model_path = f"{models_dir}/deepseek-r1-distill-llama-8b"
elif on_dataproc:
    models_dir = "/mnt/gcs/spark-dl-models"
    os.makedirs(models_dir, exist_ok=True)
    model_path = f"{models_dir}/deepseek-r1-distill-llama-8b"
else:
    model_path = os.path.abspath("deepseek-r1-distill-llama-8b")

Download the model from the Hugging Face Hub.

In [None]:
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    local_dir=model_path
)

## Warmup: Running locally

**Note:** If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.

In [3]:
import torch
from transformers import pipeline

pipe = pipeline("text-generation", model=model_path, torch_dtype=torch.bfloat16, device="cuda")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda


In [10]:
res = pipe(["How many r's are there in strawberry?"], max_new_tokens=512, temperature=0.1)
print("\n", res[0][0]['generated_text'])


 How many r's are there in strawberry? Let's count them.

First, I'll write down the word: S T R A W B E R R Y.

Now, I'll go through each letter one by one.

1. S - no R.
2. T - no R.
3. R - that's one R.
4. A - no R.
5. W - no R.
6. B - no R.
7. E - no R.
8. R - that's two R's.
9. R - that's three R's.
10. Y - no R.

So, in total, there are three R's in the word strawberry.
</think>

To determine how many **r's** are in the word **strawberry**, let's follow these steps:

1. **Write down the word:**
   
   S T R A W B E R R Y

2. **Identify and count each occurrence of the letter R:**
   
   - **1.** S - no R
   - **2.** T - no R
   - **3.** R - **1 R**
   - **4.** A - no R
   - **5.** W - no R
   - **6.** B - no R
   - **7.** E - no R
   - **8.** R - **2 R's**
   - **9.** R - **3 R's**
   - **10.** Y - no R

3. **Total count of R's:**
   
   There are **3 R's** in the word **strawberry**.

\boxed{3}


In [3]:
res = pipe(["Which number is bigger: 9.9 or 9.11?"], max_new_tokens=512, temperature=0.1)
print("\n", res[0][0]['generated_text'])


 Which number is bigger: 9.9 or 9.11? Let's see.

First, I need to compare the whole number parts of both numbers. Both 9.9 and 9.11 have the same whole number part, which is 9.

Since the whole numbers are equal, I'll compare the decimal parts. For 9.9, the decimal part is 0.9, and for 9.11, the decimal part is 0.11.

To make it easier, I can express 0.9 as 0.90. Now, comparing 0.90 and 0.11, it's clear that 0.90 is greater than 0.11.

Therefore, 9.9 is bigger than 9.11.
</think>

To determine which number is larger between **9.9** and **9.11**, let's compare them step by step.

1. **Compare the Whole Numbers:**
   - Both numbers have the same whole number part: **9**.
   
2. **Compare the Decimal Parts:**
   - **9.9** can be written as **9.90**.
   - **9.11** remains **9.11**.
   
3. **Analyze the Decimal Comparison:**
   - Compare the tenths place:
     - **9.90** has **9** in the tenths place.
     - **9.11** has **1** in the tenths place.
   - Since **9 > 1**, **9.90** is greater

In [4]:
import torch

# Unload the model from GPU memory.
del pipe
torch.cuda.empty_cache()

## PySpark

In [3]:
from pyspark.sql.types import StringType
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col, struct, length
from pyspark.ml.functions import predict_batch_udf

In [4]:
import os
import pandas as pd
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For cloud service provider (CSP) environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [6]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/02/10 09:40:01 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/02/10 09:40:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/10 09:40:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
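
With the settings above, each executor owns one GPU and each task claims one eighth of it, so up to eight tasks can share a GPU, which matches the eight executor cores. The arithmetic can be sketched as follows. The helper is illustrative only, and it assumes one CPU core per task (Spark's default for `spark.task.cpus`).

```python
def max_concurrent_tasks(executor_cores: int, executor_gpus: float,
                         task_gpu_amount: float, task_cpus: int = 1) -> int:
    """Upper bound on tasks an executor can run at once, given CPU and GPU limits."""
    by_cpu = executor_cores // task_cpus
    # Small epsilon guards against floating-point error in the division.
    by_gpu = int(executor_gpus / task_gpu_amount + 1e-9)
    return min(by_cpu, by_gpu)


# Values taken from the SparkConf above.
print(max_concurrent_tasks(executor_cores=8, executor_gpus=1, task_gpu_amount=0.125))  # prints 8
```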


#### Load DataFrame

Load the first 500 samples of the [Orca Math Word Problems dataset](https://huggingface.co/datasets/microsoft/orca-math-word-problems-200k) from Hugging Face and store them in a Spark DataFrame.

In [7]:
dataset = load_dataset("microsoft/orca-math-word-problems-200k", split="train", streaming=True)
dataset = pd.Series([sample["question"] for sample in dataset.take(500)])

In [8]:
df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed("value", "question")
df.show(5, truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                            question|
+----------------------------------------------------------------------------------------------------+
|Jungkook is the 5th place. Find the number of people who crossed the finish line faster than Jung...|
|A number divided by 10 is 6. Yoongi got the result by subtracting 15 from a certain number. What ...|
|Dongju selects a piece of paper with a number written on it, and wants to make a three-digit numb...|
|You wanted to subtract 46 from a number, but you accidentally subtract 59 and get 43. How much do...|
|The length of one span of Jinseo is about 12 centimeters (cm). When Jinseo measured the length of...|
+----------------------------------------------------------------------------------------------------+
only showing top 5 rows



                                                                                

In [9]:
data_path = "spark-dl-datasets/orca_math"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").json(data_path)

## Triton Inference Server
We'll demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.

<img src="../images/spark-server.png" alt="drawing" width="700"/>

In [10]:
from functools import partial

Import the helper class from server_utils.py:

In [11]:
sc.addPyFile("server_utils.py")

from server_utils import TritonServerManager

Define the Triton Server function:

In [12]:
def triton_server(ports, model_path):
    import time
    import signal
    import numpy as np
    import torch
    from transformers import pipeline
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = pipeline("text-generation", model=model_path, torch_dtype=torch.bfloat16, device=device)
    print(f"SERVER: Using {device} device.")

    @batch
    def _infer_fn(**inputs):
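        # Squeeze only the trailing axis so a batch of one stays a list
        # (a bare squeeze would collapse it to a single bytes object).
        prompts = np.squeeze(inputs["prompts"], axis=-1).tolist()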
        decoded_prompts = [p.decode("utf-8") for p in prompts]
        # limit responses to 256 tokens, since reasoning tasks can take a while
        responses = pipe(decoded_prompts, max_new_tokens=256, temperature=0.2, return_full_text=False)
        return {
            "responses": np.array([r[0]['generated_text'] for r in responses]).reshape(-1, 1)
        }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="deepseek-r1",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="prompts", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="responses", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=16,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()
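
The shape handling inside `_infer_fn` can be tried in isolation with NumPy. This sketch assumes the client sends prompts as a `(batch, 1)` array of UTF-8 bytes, and the helper names `decode_prompts` and `pack_responses` are hypothetical. It squeezes only the trailing axis, because a bare `np.squeeze` would also collapse the batch axis when the batch holds a single prompt.

```python
import numpy as np


def decode_prompts(prompts: np.ndarray) -> list[str]:
    """Turn a (batch, 1) array of UTF-8 bytes into a list of Python strings."""
    flat = np.squeeze(prompts, axis=-1).tolist()
    return [p.decode("utf-8") for p in flat]


def pack_responses(texts: list[str]) -> np.ndarray:
    """Pack generated strings back into a (batch, 1) array for the response."""
    return np.array(texts).reshape(-1, 1)


batch = np.array([[b"2 + 2?"], [b"3 * 3?"]], dtype=object)
single = np.array([[b"2 + 2?"]], dtype=object)

print(decode_prompts(batch))             # a list with two strings
print(decode_prompts(single))            # still a list, with one string
print(pack_responses(["4", "9"]).shape)  # prints (2, 1)
```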

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP/gRPC/metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
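
`server_utils.py` is not shown here, so the following is only a sketch of how the port-discovery step could work, not the helper's actual implementation. It scans upward from a starting port (7000 is an assumption, chosen to match the ports reported in the output below) and keeps the first ports that can be bound on the local host. The function name is hypothetical.

```python
import socket


def find_free_ports(start: int = 7000, count: int = 3, host: str = "127.0.0.1") -> list[int]:
    """Return the first `count` ports at or above `start` that can currently be bound."""
    ports = []
    port = start
    while len(ports) < count and port < 65536:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((host, port))
            except OSError:
                pass  # Port is in use or not permitted; try the next one.
            else:
                ports.append(port)
        port += 1
    if len(ports) < count:
        raise RuntimeError(f"Could not find {count} free ports starting at {start}")
    return ports


http_port, grpc_port, metrics_port = find_free_ports()
print(http_port, grpc_port, metrics_port)
```

A check like this is inherently racy: another process can take a port between the probe and the server start, so a real implementation would still need to handle a bind failure at launch.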

In [14]:
model_name = "deepseek-r1"
server_manager = TritonServerManager(model_name=model_name, model_path=model_path)

In [15]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading

2025-02-10 09:40:17,442 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-10 09:40:17,442 - INFO - Starting 1 servers.
                                                                                

{'cb4ae00-lcedt': (272659, [7000, 7001, 7002])}

#### Define client function

Get the hostname -> url mapping from the server manager:

In [16]:
host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url

Define the Triton inference function, which returns a predict function for batch inference through the server:

In [17]:
def triton_fn(model_name, host_to_url):
    import socket
    import numpy as np
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=500) as client:
            # Squeeze only the trailing axis so a single-row batch stays a list of strings.
            flattened = np.squeeze(inputs, axis=-1).tolist()
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["responses"], -1)
            return result_data
        
    return infer_batch

In [18]:
generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=2)

#### Load and preprocess DataFrame

We'll select a few of the shorter questions for demonstration, since reasoning tasks can take a while.

In [19]:
df = spark.read.json(data_path)
df = df.filter(length(col("question")) <= 100).limit(16).repartition(8).cache()
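
To get a feel for how this partitioning interacts with the UDF's `batch_size=2`, the number of inference requests sent to Triton can be estimated per partition. The sketch below is illustrative only. It assumes `repartition(8)` spreads the 16 rows evenly, two per partition, which Spark does not strictly guarantee, and the function name is hypothetical.

```python
import math


def num_requests(rows_per_partition: list[int], batch_size: int) -> int:
    """Count the batches (one client request each) produced across all partitions."""
    return sum(math.ceil(rows / batch_size) for rows in rows_per_partition)


# Assumed even split: 16 rows over 8 partitions.
print(num_requests([2] * 8, batch_size=2))  # prints 8

# An uneven split produces an extra, partially filled batch.
print(num_requests([3, 1] + [2] * 6, batch_size=2))  # prints 9
```

On the server side, the dynamic batcher may still coalesce several of these requests into one model call, up to the configured `max_batch_size=16`.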

#### Run Inference

In [20]:
%%time
# first pass caches model/fn
preds = df.withColumn("response", generate(col("question")))
results = preds.collect()



CPU times: user 18.6 ms, sys: 8.31 ms, total: 26.9 ms
Wall time: 1min 46s


                                                                                

In [30]:
%%time
preds = df.withColumn("response", generate("question"))
results = preds.collect()



CPU times: user 9.55 ms, sys: 4.51 ms, total: 14.1 ms
Wall time: 1min 45s


                                                                                

Sample output:

In [21]:
print(f"Q: {results[2].question} \n")
print(f"A: {results[2].response} \n")

Q: There are 9 dogs and 23 cats. How many more cats are there than dogs? 

A:  Let me think. So, I have 23 cats and 9 dogs. To find out how many more cats there are than dogs, I need to subtract the number of dogs from the number of cats. That would be 23 minus 9. Let me do the subtraction: 23 minus 9 is 14. So, there are 14 more cats than dogs.

Wait, let me double-check that. If I have 9 dogs and 23 cats, subtracting the number of dogs from the number of cats should give me the difference. So, 23 minus 9 is indeed 14. Yeah, that seems right. I don't think I made a mistake there. So, the answer is 14 more cats than dogs.

**Final Answer**
The number of cats exceeds the number of dogs by \boxed{14}.
\boxed{14}
</think>

To determine how many more cats there are than dogs, we subtract the number of dogs from the number of cats. 

Given:
- Number of cats = 23
- Number of dogs = 9

The calculation is:
\[ 23 - 9 = 14 \]

Thus, there are 14 more cats than dogs.

\[
\boxed{14}
\] 



#### Shut down server on each executor

In [30]:
server_manager.stop_servers()

2025-02-10 09:43:36,499 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)


2025-02-10 09:43:41,701 - INFO - Sucessfully stopped 1 servers.                 


[True]

In [31]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()