<img src="http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png" width="90px">

# PySpark LLM Inference: Gemma-7b Code Comprehension

In this notebook, we demonstrate distributed inference with Google's [Gemma-7b-instruct](https://huggingface.co/google/gemma-7b-it) LLM, using the open weights hosted on Hugging Face.

Gemma-7b-instruct is an instruction-fine-tuned version of the Gemma-7b base model. We'll show how to use it for code comprehension tasks.

**Note:** Running this model on GPU with 16-bit precision requires **~18 GB** of GPU RAM. Make sure your instances have sufficient GPU capacity.
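
As a rough sanity check on that figure, the memory taken by the weights alone can be estimated from the parameter count. The sketch below assumes about 8.5 billion parameters for Gemma-7b (an assumption; check the model card for the exact count) and 2 bytes per parameter for 16-bit precision. Activations, the KV cache, and CUDA context overhead come on top of this, so the practical requirement is higher than the weights alone.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Estimate the memory needed for model weights, in gigabytes (10^9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Assumption: ~8.5e9 parameters for Gemma-7b; bfloat16 stores 2 bytes per parameter.
ASSUMED_PARAMS = 8.5e9
estimate = weight_memory_gb(ASSUMED_PARAMS)
print(f"Approximate weight memory: {estimate:.1f} GB")
```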

In [None]:
# Explicitly enable Hugging Face tokenizer parallelism so that it is not disabled when combined with PySpark parallelism.
# See https://github.com/huggingface/transformers/issues/5486 for more info.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

Check the cluster environment to handle any platform-specific configurations.

In [13]:
on_databricks = os.environ.get("DATABRICKS_RUNTIME_VERSION", False)
on_dataproc = os.environ.get("DATAPROC_IMAGE_VERSION", False)
on_standalone = not (on_databricks or on_dataproc)

In [14]:
# For cloud environments, load the model to the distributed file system.
if on_databricks:
    models_dir = "/dbfs/FileStore/spark-dl-models"
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")
    model_path = f"{models_dir}/gemma-7b-it"
elif on_dataproc:
    models_dir = "/mnt/gcs/spark-dl-models"
    os.makedirs(models_dir, exist_ok=True)
    model_path = f"{models_dir}/gemma-7b-it"
else:
    model_path = os.path.abspath("gemma-7b-it")

First, visit the [Gemma Hugging Face repository](https://huggingface.co/google/gemma-7b-it) and accept the terms to access the model, then log in via `huggingface_hub`.

In [None]:
from huggingface_hub import login

login()

Once you have access, you can download the model:

In [None]:
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="google/gemma-7b-it",
    local_dir=model_path,
    ignore_patterns="*.gguf"
)

## Warmup: Running locally

**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section.

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             device_map="auto",
                                             torch_dtype=torch.bfloat16)


In [10]:
input_text = "Write me a poem about Apache Spark."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.1, do_sample=True)
print(tokenizer.decode(outputs[0]))

<bos>Write me a poem about Apache Spark.

In the realm of big data, a spark ignites,
A framework born to conquer the night.
Apache Spark, a lightning-fast tool,
For processing data, swift and cool.

With its resilient distributed architecture,
It slices through terabytes with grace.
No longer bound by memory's plight,
Spark empowers us to analyze with might.

From Python to Scala, it's a versatile spark,
Unveiling insights hidden in the dark.



In [4]:
import torch

# Unload the model from GPU memory.
del model
torch.cuda.empty_cache()

## PySpark

In [4]:
from pyspark.sql.types import *
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.functions import predict_batch_udf

In [5]:
import os
import pandas as pd
import datasets
from datasets import load_dataset
datasets.disable_progress_bars()

#### Create Spark Session

For local standalone clusters, we'll connect to the cluster and create the Spark Session.  
For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc).

In [6]:
conf = SparkConf()

if 'spark' not in globals():
    if on_standalone:
        import socket
        conda_env = os.environ.get("CONDA_PREFIX")
        hostname = socket.gethostname()
        conf.setMaster(f"spark://{hostname}:7077")
        conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
        conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")

    conf.set("spark.executor.cores", "8")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.task.resource.gpu.amount", "0.125")
    conf.set("spark.executor.resource.gpu.amount", "1")
    conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    conf.set("spark.python.worker.reuse", "true")

spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

25/02/10 09:44:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
25/02/10 09:44:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/10 09:44:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
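
The GPU settings in the configuration above are related: with one GPU and 8 cores per executor, a per-task GPU amount of 1/8 lets all 8 task slots on an executor share that GPU. A minimal sketch of the arithmetic (the helper name is hypothetical, not a Spark API):

```python
from fractions import Fraction


def gpu_amount_per_task(executor_gpus: int, executor_cores: int) -> float:
    """GPU fraction per task so that every core on an executor can run a task concurrently."""
    return float(Fraction(executor_gpus, executor_cores))


# Values taken from spark.executor.resource.gpu.amount and spark.executor.cores.
task_gpu = gpu_amount_per_task(executor_gpus=1, executor_cores=8)
print(task_gpu)  # 0.125
```

If you change `spark.executor.cores`, adjusting `spark.task.resource.gpu.amount` by the same ratio keeps every core usable.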


#### Load DataFrame

Load the first 500 samples of the [Code Comprehension dataset](https://huggingface.co/datasets/imbue/code-comprehension) from Hugging Face and store them in a Spark DataFrame.

In [7]:
dataset = load_dataset("imbue/code-comprehension", split="train", streaming=True)
dataset = pd.Series([sample["question"] for sample in dataset.take(500)])

In [8]:
df = spark.createDataFrame(dataset, schema=StringType()).withColumnRenamed("value", "prompt")

In [10]:
df.show(5, truncate=100)

+----------------------------------------------------------------------------------------------------+
|                                                                                              prompt|
+----------------------------------------------------------------------------------------------------+
|If we execute the code below, what will `result` be equal to?\n\n```python\nN = 'quz'\nN += 'bar'...|
|```python\nresult = 9 - 9 - 1 - 7 - 9 - 1 + 9 - 2 + 6 - 4 - 8 - 1\n```\n\nOut of these options, w...|
|```python\nx = 'bas'\nD = 'bar'.swapcase()\nx = len(x)\nx = str(x)\nnu = 'bar'.isnumeric()\nx += ...|
|If we execute the code below, what will `result` be equal to?\n\n```python\n\nl = 'likewise'\nmat...|
|```python\nresult = 'mazda' + 'isolated' + 'mistakes' + 'grew' + 'raid' + 'junk' + 'jamaica' + 'c...|
+----------------------------------------------------------------------------------------------------+
only showing top 5 rows



In [11]:
print(df.take(1)[0].prompt)

If we execute the code below, what will `result` be equal to?

```python
N = 'quz'
N += 'bar'
N = N.swapcase()
N = len(N)
mu = 'bar'.strip()
N = str(N)
Q = N.isalpha()
if N == 'bawr':
    N = 'BAWR'.lower()
N = N + N
N = '-'.join([N, N, N, 'foo'])
if mu == N:
    N = 'bar'.upper()
gamma = 'BAZ'.lower()

result = N
```


In [12]:
data_path = "spark-dl-datasets/code_comprehension"
if on_databricks:
    dbutils.fs.mkdirs("/FileStore/spark-dl-datasets")
    data_path = "dbfs:/FileStore/" + data_path

df.write.mode("overwrite").json(data_path)

## Using Triton Inference Server
In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  
We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  

The process looks like this:
- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.
- Define a Triton inference function, which contains a client that connects to the local server on a given node and sends inference requests.
- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.
- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.
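
One way to make sure the last step always runs, even when inference raises, is to wrap the start/stop pair in a context manager. The sketch below is hypothetical and not part of `server_utils.py`; it only assumes a manager object exposing `start_servers` and `stop_servers`, and uses a stand-in class to show the behavior.

```python
from contextlib import contextmanager


@contextmanager
def running_servers(manager, server_fn, **start_kwargs):
    """Start servers on entry and always stop them on exit, even if inference fails."""
    manager.start_servers(server_fn, **start_kwargs)
    try:
        yield manager
    finally:
        manager.stop_servers()


class FakeManager:
    """Hypothetical stand-in for a server manager, used only to exercise the sketch."""

    def __init__(self):
        self.events = []

    def start_servers(self, server_fn, **kwargs):
        self.events.append("start")

    def stop_servers(self):
        self.events.append("stop")


fake = FakeManager()
try:
    with running_servers(fake, server_fn=None):
        raise RuntimeError("inference failed")
except RuntimeError:
    pass
print(fake.events)  # ['start', 'stop']
```

In a notebook, running the cells one by one gives the same result as long as the shutdown cell is executed at the end.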

<img src="../images/spark-server.png" alt="drawing" width="700"/>

In [15]:
from functools import partial

Import the helper class from server_utils.py:

In [16]:
sc.addPyFile("server_utils.py")

from server_utils import TritonServerManager

Define the Triton Server function:

In [37]:
def triton_server(ports, model_path):
    import time
    import signal
    import numpy as np
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from pytriton.decorators import batch
    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
    from pytriton.triton import Triton, TritonConfig
    from pyspark import TaskContext

    print(f"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    print(f"SERVER: Using {device} device.")

    @batch
    def _infer_fn(**inputs):
        # Flatten to 1-D while keeping the batch axis, so a batch of one still yields a list.
        prompts = np.reshape(inputs["prompts"], -1).tolist()
        print(f"SERVER: Received batch of size {len(prompts)}")
        decoded_prompts = [p.decode("utf-8") for p in prompts]
        tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors="pt").to(device)
        outputs = model.generate(**tokenized_inputs, max_new_tokens=256, temperature=0.1, do_sample=True)
        # Decode only the model output (excluding the input prompt) and remove special tokens.
        responses = np.array(tokenizer.batch_decode(outputs[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens = True))
        return {
            "responses": responses.reshape(-1, 1),
        }

    workspace_path = f"/tmp/triton_{time.strftime('%m_%d_%M_%S')}"
    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
    with Triton(config=triton_conf, workspace=workspace_path) as triton:
        triton.bind(
            model_name="gemma-7b",
            infer_func=_infer_fn,
            inputs=[
                Tensor(name="prompts", dtype=object, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="responses", dtype=object, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=16,
                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms
            ),
            strict=True,
        )

        def _stop_triton(signum, frame):
            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.
            print("SERVER: Received SIGTERM. Stopping Triton server.")
            triton.stop()

        signal.signal(signal.SIGTERM, _stop_triton)

        print("SERVER: Serving inference")
        triton.serve()
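
The slicing step in `_infer_fn` relies on `generate` returning, for each row, the prompt tokens followed by the newly generated tokens. Because the prompts in a batch are padded to a common length, one column offset removes the prompt from every row. A toy illustration with plain lists (the token IDs are made up and the helper name is hypothetical):

```python
def strip_prompt_tokens(output_ids, prompt_length):
    """Drop the first `prompt_length` token IDs from every generated sequence."""
    return [row[prompt_length:] for row in output_ids]


# Toy token IDs (not real Gemma vocabulary): two prompts padded to length 3,
# each followed by two newly generated tokens.
PAD = 0
outputs = [
    [PAD, 11, 12, 101, 102],
    [21, 22, 23, 201, 202],
]
new_tokens = strip_prompt_tokens(outputs, prompt_length=3)
print(new_tokens)  # [[101, 102], [201, 202]]
```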

#### Start Triton servers

The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:
- Find available ports for HTTP/gRPC/metrics
- Deploy a server on each node via stage-level scheduling
- Gracefully shut down servers across nodes
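
For illustration, here is one common way to find free ports: bind to port 0 and let the operating system choose. This is a hypothetical sketch, not the implementation in `server_utils.py`, which may select ports differently. Note that a port can be claimed by another process between discovery and use, so a real implementation should be prepared to retry.

```python
import socket


def find_free_ports(count: int = 3) -> list:
    """Ask the OS for `count` distinct free TCP ports by binding to port 0."""
    sockets = []
    try:
        for _ in range(count):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(("127.0.0.1", 0))  # port 0 lets the OS pick an unused port
            sockets.append(s)
        # All sockets are still open here, so the ports are guaranteed to be distinct.
        return [s.getsockname()[1] for s in sockets]
    finally:
        for s in sockets:
            s.close()


http_port, grpc_port, metrics_port = find_free_ports(3)
print(http_port, grpc_port, metrics_port)
```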

In [39]:
model_name = "gemma-7b"
server_manager = TritonServerManager(model_name=model_name, model_path=model_path)

In [None]:
# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}
server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading

2025-02-10 09:06:38,803 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)
2025-02-10 09:06:38,805 - INFO - Starting 1 servers.


                                                                                

{'cb4ae00-lcedt': (252119, [7000, 7001, 7002])}

#### Define client function

Get the hostname -> url mapping from the server manager:

In [41]:
host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url

Define the Triton inference function, which returns a predict function for batch inference through the server:

In [42]:
def triton_fn(model_name, host_to_url):
    import socket
    import numpy as np
    from pytriton.client import ModelClient

    url = host_to_url[socket.gethostname()]
    print(f"Connecting to Triton model {model_name} at {url}.")

    def infer_batch(inputs):
        with ModelClient(url, model_name, inference_timeout_s=500) as client:
            # Flatten to 1-D while keeping the batch axis, so a batch of one still yields a list.
            flattened = np.reshape(inputs, -1).tolist()
            # Encode batch
            encoded_batch = [[text.encode("utf-8")] for text in flattened]
            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)
            # Run inference
            result_data = client.infer_batch(encoded_batch_np)
            result_data = np.squeeze(result_data["responses"], -1)
            return result_data
        
    return infer_batch
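
The client and server exchange prompts as UTF-8 bytes in a NumPy array of shape `(batch, 1)`. The round trip can be checked without a running server; the two helper names below are hypothetical and only mirror the encode and decode steps used in this notebook.

```python
import numpy as np


def encode_prompts(prompts):
    """Encode strings as UTF-8 bytes in a (batch, 1) array, as the client does."""
    return np.array([[p.encode("utf-8")] for p in prompts], dtype=np.bytes_)


def decode_prompts(batch):
    """Flatten a (batch, 1) bytes array and decode each entry back to str."""
    return [p.decode("utf-8") for p in np.reshape(batch, -1).tolist()]


batch = encode_prompts(["print(1 + 1)", "résumé"])
print(batch.shape)  # (2, 1)
print(decode_prompts(batch))
```

Flattening with `reshape(-1)` keeps the batch axis, so a batch containing a single prompt still decodes to a one-element list.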

In [43]:
generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=4)

#### Load and preprocess DataFrame

We'll parallelize over a small set of questions for demonstration.

In [44]:
df = spark.read.json(data_path).limit(32).repartition(8)
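
These numbers line up with the batch sizes configured earlier. Assuming `repartition` spreads the 32 rows evenly, each of the 8 partitions holds 4 rows, which fits in a single client request at `batch_size=4`, and the server's `max_batch_size=16` leaves room for the dynamic batcher to merge up to four such requests. A small sketch of that arithmetic (helper names are hypothetical):

```python
import math


def client_requests_per_partition(num_rows, num_partitions, udf_batch_size):
    """Number of inference requests each partition sends, assuming an even split of rows."""
    rows_per_partition = math.ceil(num_rows / num_partitions)
    return math.ceil(rows_per_partition / udf_batch_size)


def client_batches_per_server_batch(server_max_batch_size, udf_batch_size):
    """How many full client batches fit into one server-side batch."""
    return server_max_batch_size // udf_batch_size


print(client_requests_per_partition(32, 8, 4))  # 1
print(client_batches_per_server_batch(16, 4))   # 4
```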

#### Run Inference

In [45]:
%%time
# first pass caches model/fn
preds = df.withColumn("response", generate(col("prompt")))
results = preds.collect()



CPU times: user 5.6 ms, sys: 3.51 ms, total: 9.11 ms
Wall time: 28.1 s


                                                                                

In [51]:
%%time
preds = df.withColumn("response", generate("prompt"))
results = preds.collect()



CPU times: user 8.12 ms, sys: 3.13 ms, total: 11.2 ms
Wall time: 23.1 s


                                                                                

Sample output:

In [54]:
print(f"Q: {results[2].prompt} \n")
print(f"A: {results[2].response} \n")

Q: ```python
result = ['mirrors', 'limousines', 'meaningful', 'cats', UNKNOWN, 'striking', 'wings', 'injured', 'wishlist', 'granny'].index('oracle')
print(result)
```

The code above has one or more parts replaced with the word UNKNOWN. Knowing that running the code prints `4` to the console, what should go in place of UNKNOWN? 

A: 

The answer is `oracle`.

The code is searching for the index of the word `oracle` in the list `result`, and the index is returned as `4`. 



#### Shut down server on each executor

In [55]:
server_manager.stop_servers()

2025-02-10 09:11:11,880 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)


2025-02-10 09:11:17,105 - INFO - Sucessfully stopped 1 servers.                 


[True]

In [56]:
if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state
    spark.stop()