# PySpark Huggingface Inferencing
## Sentence Transformers with PyTorch

From: https://huggingface.co/sentence-transformers

In [1]:
from sentence_transformers import SentenceTransformer
# Load the 'paraphrase-MiniLM-L6-v2' model by name.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Sentences we want to encode: a list holding a single example sentence.
sentence = ['This framework generates embeddings for each input sentence']


# Sentences are encoded by calling model.encode(); the result is indexed per
# input sentence (the next cell prints the first 20 values of embedding[0]).
embedding = model.encode(sentence)

  from tqdm.autonotebook import tqdm, trange


In [2]:
print(embedding[0][:20])

[-0.17621444  0.1206013  -0.29362372 -0.22985819 -0.08229247  0.2377093
  0.33998525 -0.7809643   0.11812777  0.16337365 -0.13771524  0.24028276
  0.4251256   0.17241786  0.10527937  0.5181643   0.062222    0.39928585
 -0.18165241 -0.58557856]


## PySpark

## Inference using Spark DL API
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [3]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql import SparkSession
from pyspark import SparkConf
from datasets import load_dataset

In [4]:
import os

# Root of the active conda environment; None when CONDA_PREFIX is not set.
conda_env = os.environ.get("CONDA_PREFIX")

conf = SparkConf()
if 'spark' not in globals():
    # If Spark is not already started with Jupyter, attach to Spark Standalone
    import socket
    hostname = socket.gethostname()
    conf.setMaster(f"spark://{hostname}:7077") # assuming Master is on default port 7077
conf.set("spark.task.maxFailures", "1")
conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
if conda_env:
    # Pin the driver/worker interpreters only when a conda environment is active.
    # Without this guard an unset CONDA_PREFIX would produce the literal path
    # "None/bin/python"; leaving the keys unset lets Spark use its defaults.
    conf.set("spark.pyspark.python", f"{conda_env}/bin/python")
    conf.set("spark.pyspark.driver.python", f"{conda_env}/bin/python")
# Full Python tracebacks plus JVM stack traces make UDF failures easier to debug.
conf.set("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled", "false")
conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
conf.set("spark.python.worker.reuse", "true")
# Create Spark Session
spark = SparkSession.builder.appName("spark-dl-examples").config(conf=conf).getOrCreate()
sc = spark.sparkContext

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
24/10/08 00:19:28 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)
24/10/08 00:19:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/08 00:19:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
# load IMDB reviews (test) dataset and write to parquet
data = load_dataset("imdb", split="test")

# One single-column row per review, keeping only the text before the first period.
lines = [[example["text"].split(".")[0]] for example in data]

df = spark.createDataFrame(lines, ['lines']).repartition(10)

# Overwrite any previous copy so the cell can be re-run safely.
df.write.mode("overwrite").parquet("imdb_test")

                                                                                

In [6]:
# Only use the first 100 rows of the parquet data written above, since inference
# is slow; the limited DataFrame is marked for caching.
df = spark.read.parquet("imdb_test").limit(100).cache()

In [7]:
df.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|                                       This is so overly clichéd you'll want to switch it off after the first 45 minutes|
|                                                                                   I was very disappointed by this movie|
|                                                                             I think vampire movies (usually) are wicked|
|                           Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended to be|
|This film did well at the box office, and the producers of this mess thought the stars had such good chemistry in thi...|
|               

In [8]:
def predict_batch_fn():
    """Build the batch prediction function handed to `predict_batch_udf`.

    The `sentence_transformers` import and the model load both happen inside
    this factory, so they run wherever the factory itself is invoked.

    Returns:
        A callable taking one batch object that exposes `.tolist()` and
        returning whatever `SentenceTransformer.encode` yields for that list.
    """
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer("paraphrase-MiniLM-L6-v2")

    def predict(inputs):
        """Encode one batch of inputs with the loaded model."""
        sentences = inputs.tolist()
        return encoder.encode(sentences)

    return predict

In [9]:
# Wrap the predict_batch_fn factory as a Spark UDF: the declared return type is
# array<float> and batch_size is 10.
encode = predict_batch_udf(predict_batch_fn,
                           return_type=ArrayType(FloatType()),
                           batch_size=10)

In [10]:
%%time
# first pass caches model/fn
# (here the input column is passed wrapped in a struct)
embeddings = df.withColumn("encoding", encode(struct("lines")))
# collect() pulls every row back to the driver, which forces the UDF to run.
results = embeddings.collect()

[Stage 9:>                                                          (0 + 1) / 1]

CPU times: user 4.34 ms, sys: 4.15 ms, total: 8.48 ms
Wall time: 2.58 s


                                                                                

In [11]:
%%time
# Same inference, this time passing the input column by name.
embeddings = df.withColumn("encoding", encode("lines"))
results = embeddings.collect()

[Stage 11:>                                                         (0 + 1) / 1]

CPU times: user 1.76 ms, sys: 4.89 ms, total: 6.65 ms
Wall time: 2.47 s


                                                                                

In [12]:
%%time
# Same inference, this time passing the input column as a Column via col().
embeddings = df.withColumn("encoding", encode(col("lines")))
results = embeddings.collect()

[Stage 13:>                                                         (0 + 1) / 1]

CPU times: user 1.55 ms, sys: 6.05 ms, total: 7.6 ms
Wall time: 2.46 s


                                                                                

In [13]:
embeddings.show(truncate=60)

[Stage 15:>                                                         (0 + 1) / 1]

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|This is so overly clichéd you'll want to switch it off af...|[-0.06755405, -0.13365394, 0.36675274, -0.2772311, -0.085...|
|                       I was very disappointed by this movie|[-0.05903806, 0.16684641, 0.16768408, 0.10940918, 0.18100...|
|                 I think vampire movies (usually) are wicked|[0.025601083, -0.5308639, -0.319133, -0.013351389, -0.338...|
|Though not a complete waste of time, 'Eighteen' really wa...|[0.20991832, 0.5228605, 0.44517252, -0.031682555, -0.4117...|
|This film did well at the box office, and the producers o...|[0.18097948, -0.03622232, -0.34149718, 0.061557338, -0.06...|
|Peter C

                                                                                

### Using Triton Inference Server

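Note: you can restart the kernel and run from this point to simulate running in a different node or environment, provided a Spark session is available. The cells below use `spark` and `sc` and read the `imdb_test` parquet data, so unless your Jupyter environment already provides a Spark session, re-run the Spark session setup cell above first.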

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:
```
conda create -n huggingface-torch -c conda-forge python=3.10.0
conda activate huggingface-torch

export PYTHONNOUSERSITE=True
pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface-torch.tar.gz
```

In [14]:
import numpy as np
import os
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [15]:
%%bash
# Abort on the first failing command so that a missing model config or
# environment archive is reported instead of leaving a partial models/ directory.
set -euo pipefail

# copy custom model to expected layout for Triton
rm -rf models
mkdir -p models
cp -r models_config/hf_transformer_torch models

# add custom execution environment
cp huggingface-torch.tar.gz models

#### Start Triton Server on each executor

In [16]:
# Number of partitions in nodeRDD below.
num_executors = 1

# Host directories that start_triton bind-mounts into the Triton container.
triton_models_dir = f"{os.getcwd()}/models"
huggingface_cache_dir = f"{os.path.expanduser('~')}/.cache/huggingface"

# One element per partition, so a mapPartitions call runs once per partition.
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it, ready_timeout=300, poll_interval=5):
    """Ensure a Triton Inference Server container is running and ready.

    Intended to be run through ``mapPartitions``. If a running container whose
    name matches "spark-triton" is found it is reused; otherwise a new one is
    launched from the ``tritonserver:24.08-py3`` image with the model repository
    and Hugging Face cache directories mounted. In both cases the function then
    polls the server's gRPC endpoint on ``localhost:8001`` until it reports
    ready.

    Args:
        it: Partition iterator supplied by ``mapPartitions``; not consumed.
        ready_timeout: Maximum number of seconds to wait for readiness.
        poll_interval: Seconds to sleep between readiness checks.

    Returns:
        ``[True]`` once the server reports ready.

    Raises:
        TimeoutError: If the server is not ready within ``ready_timeout`` seconds.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container = docker_client.containers.run(
            "nvcr.io/nvidia/tritonserver:24.08-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="512M",
            volumes={
                triton_models_dir: {"bind": "/models", "mode": "ro"},
                huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

    # Wait for Triton to be ready. This also runs when a container already
    # existed, because another task may have launched it only moments ago.
    triton_client = grpcclient.InferenceServerClient("localhost:8001")
    deadline = time.monotonic() + ready_timeout
    last_error = None
    while True:
        try:
            if triton_client.is_server_ready():
                break
        except Exception as exc:
            # The endpoint is typically unreachable until the server has
            # finished starting; remember the error for the timeout message.
            last_error = exc
        if time.monotonic() >= deadline:
            raise TimeoutError(
                "Triton server at localhost:8001 not ready after {} seconds "
                "(last error: {})".format(ready_timeout, last_error)
            )
        # Sleep on every unsuccessful check, including a plain "not ready"
        # answer, so the loop never busy-spins.
        time.sleep(poll_interval)

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

                                                                                

[True]

#### Run inference

In [17]:
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, struct
from pyspark.sql.types import ArrayType, FloatType

In [18]:
# Only use the first 100 rows, since inference is slow. This repeats the read
# from the Spark DL API section, which is why Spark warns below that the data
# is already cached.
df = spark.read.parquet("imdb_test").limit(100).cache()

24/10/08 00:20:24 WARN CacheManager: Asked to cache already cached data.


In [19]:
def triton_fn(triton_uri, model_name):
    """Build a predict function that forwards batches to a Triton-served model.

    A gRPC client is created and the model's metadata is fetched once, when this
    factory is called; the returned closure reuses both for every batch.

    Args:
        triton_uri: Address passed to ``grpcclient.InferenceServerClient``.
        model_name: Name of the model to query and run inference against.

    Returns:
        ``predict(inputs)``, where ``inputs`` is either a single numpy array
        (sent as the model's first input) or a mapping from input name to numpy
        array. It returns a single numpy array when the model declares one
        output, otherwise a dict mapping output name to numpy array.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype that inputs are cast to before sending.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object),
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    model_meta = client.get_model_metadata(model_name)

    def _to_infer_input(spec, array):
        # Build one request tensor from a metadata entry and its numpy data.
        tensor = grpcclient.InferInput(spec.name, array.shape, spec.datatype)
        tensor.set_data_from_numpy(array.astype(np_types[spec.datatype]))
        return tensor

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            request = [_to_infer_input(model_meta.inputs[0], inputs)]
        else:
            # dict of multiple ndarray inputs, keyed by model input name
            request = [_to_infer_input(spec, inputs[spec.name]) for spec in model_meta.inputs]

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        # return single numpy array
        return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [20]:
# Spark UDF backed by triton_fn, bound to the "hf_transformer_torch" model at
# localhost:8001. The declared return type is array<float> and batch_size is 100.
# NOTE(review): input_tensor_shapes=[[1]] presumably declares a single input with
# per-row shape [1] — confirm against the predict_batch_udf documentation.
encode = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="hf_transformer_torch"),
                           return_type=ArrayType(FloatType()),
                           input_tensor_shapes=[[1]],
                           batch_size=100)

In [21]:
%%time
# first pass caches model/fn
# (here the input column is passed wrapped in a struct)
embeddings = df.withColumn("encoding", encode(struct("lines")))
# collect() pulls every row back to the driver, which forces the UDF to run.
results = embeddings.collect()

CPU times: user 4.65 ms, sys: 2.85 ms, total: 7.49 ms
Wall time: 480 ms


In [22]:
%%time
# Same inference, this time passing the input column by name.
embeddings = df.withColumn("encoding", encode("lines"))
results = embeddings.collect()

CPU times: user 1.45 ms, sys: 1.1 ms, total: 2.56 ms
Wall time: 384 ms


In [23]:
%%time
# Same inference, this time passing the input column as a Column via col().
embeddings = df.withColumn("encoding", encode(col("lines")))
results = embeddings.collect()

CPU times: user 1.63 ms, sys: 1.28 ms, total: 2.91 ms
Wall time: 416 ms


In [24]:
embeddings.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       lines|                                                    encoding|
+------------------------------------------------------------+------------------------------------------------------------+
|This is so overly clichéd you'll want to switch it off af...|[-0.06755393, -0.1336537, 0.366753, -0.2772312, -0.085145...|
|                       I was very disappointed by this movie|[-0.059038587, 0.1668467, 0.16768396, 0.10940957, 0.18100...|
|                 I think vampire movies (usually) are wicked|[0.025601566, -0.5308643, -0.31913283, -0.013350786, -0.3...|
|Though not a complete waste of time, 'Eighteen' really wa...|[0.2099183, 0.5228606, 0.4451728, -0.031682458, -0.411756...|
|This film did well at the box office, and the producers o...|[0.1809797, -0.036222238, -0.34149715, 0.06155738, -0.066...|
|Peter C

#### Stop Triton Server on each executor

In [25]:
def stop_triton(it):
    """Stop every running container whose name matches "spark-triton".

    Intended to be run through ``mapPartitions``. All matched containers are
    stopped, so the set of containers actually stopped agrees with the list
    printed in the log message.

    Args:
        it: Partition iterator supplied by ``mapPartitions``; not consumed.

    Returns:
        ``[True]``, whether or not any container was found.
    """
    import docker

    client = docker.from_env()
    containers = client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    for container in containers:
        # Give the server up to 120 seconds to shut down before the stop is forced.
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

                                                                                

[True]

In [26]:
spark.stop()