In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate Text Embeddings with a Hugging Face Model Using Apache Spark

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.10.13

## Overview
The example builds a similarity search over Stack Overflow questions to identify similar topics, questions, and technologies being discussed. It uses BigQuery and Dataproc Serverless for distributed prediction with deep learning models.

Data Engineers and Data Scientists with existing working knowledge of BigQuery and Dataproc/Spark can use this notebook to launch batch inference jobs at scale.

### Objective

In this tutorial, you learn how to use Apache Spark for batch inference and BigQuery for vector search. You also learn how to use Dataproc Interactive Sessions from Jupyter notebooks, either from a Vertex Workbench Instance or from BQ Studio/Colab Enterprise.

The example uses open source Stack Overflow data and the open source Hugging Face model all-MiniLM-L12-v2 to generate embeddings of text data. The similarity search index is created in BigQuery.

This tutorial uses the following Google Cloud ML services and resources:

- BQML - Vector Search

### Dataset

BigQuery public dataset - "bigquery-public-data.stackoverflow"

### Costs 

This tutorial uses billable components of Google Cloud:

* Dataproc Serverless
* BigQuery
* Vertex Workbench Instance / BQ Studio

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing),
[BigQuery pricing](https://cloud.google.com/bigquery/pricing),
and [Dataproc Serverless pricing](https://cloud.google.com/dataproc-serverless/pricing),
and use the [Pricing Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

3. Enable the [Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com) and the [Dataproc API](https://console.cloud.google.com/flows/enableapi?apiid=dataproc.googleapis.com).

4. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).

5. Check that your network meets the requirements for Dataproc Serverless; the session template guide in the next section describes the network configuration.

## Setup & Installation

#### Select Dataproc Serverless Interactive Session as the Kernel for this notebook

Create a [Dataproc Interactive Session Template](https://cloud.google.com/dataproc-serverless/docs/guides/create-serverless-sessions-templates) using the network configuration specified in the link.

Once the template is created, select it as the kernel for the notebook. This creates a Dataproc Interactive Session, which you can [check here](https://console.cloud.google.com/dataproc/interactive?).

This may take a while, so please don't close the notebook.

In [1]:
!pip install sentence-transformers transformers
!pip install torchvision google-cloud-storage

Looking in indexes: https://us-python.pkg.dev/artifact-registry-python-cache/virtual-python/simple/
Collecting sentence-transformers
  Downloading https://us-python.pkg.dev/artifact-registry-python-cache/virtual-python/sentence-transformers/sentence_transformers-3.3.1-py3-none-any.whl (268 kB)
Collecting transformers
  Downloading https://us-python.pkg.dev/artifact-registry-python-cache/virtual-python/transformers/transformers-4.47.1-py3-none-any.whl (10.1 MB)
Collecting tqdm (from sentence-transformers)
  Downloading https://us-python.pkg.dev/artifact-registry-python-cache/virtual-python/tqdm/tqdm-4.67.1-py3-none-any.whl (78 kB)

In [1]:
!pip uninstall -y numpy
!pip install numpy==1.26

Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
Looking in indexes: https://us-python.pkg.dev/artifact-registry-python-cache/virtual-python/simple/
Collecting numpy==1.26
  Downloading https://us-python.pkg.dev/artifact-registry-python-cache/virtual-python/numpy/numpy-1.26.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.9 MB)
Installing collected packages: numpy
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
deepspeed 0.14.5 requires ninja, which is not installed.
deepspeed 0.14.5 requires nvidia-ml-py, which is not installed.
ydata-profiling 0.0.dev0 requires wordcloud>=1.9.1, but you have wordcloud 0.0.0 which is incompatible.

Because of dependency constraints among the Hugging Face libraries, we pin the numpy version to 1.26.

In [2]:
import numpy as np
np.__version__

'1.26.4'

Please do not forget to restart the kernel now so that the pinned numpy version is picked up!
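After the restart, it can help to confirm that the loaded numpy really belongs to the 1.26 series. The install log above shows pip resolving `numpy==1.26` to 1.26.0, while the version cell still reports 1.26.4, which is the kind of mismatch a restart is meant to clear. The helper below is a minimal sketch; `matches_release_prefix` is a hypothetical name introduced here for illustration and is not part of numpy or pip.

```python
def matches_release_prefix(version: str, prefix: str) -> bool:
    """Return True if `version` starts with the release components in `prefix`.

    For example, "1.26.4" matches the prefix "1.26", but "2.0.1" does not.
    """
    version_parts = version.split(".")
    prefix_parts = prefix.split(".")
    return version_parts[: len(prefix_parts)] == prefix_parts


# Hypothetical usage in the notebook after restarting the kernel:
#   import numpy as np
#   assert matches_release_prefix(np.__version__, "1.26")
print(matches_release_prefix("1.26.4", "1.26"))  # True
print(matches_release_prefix("2.0.1", "1.26"))   # False
```

Comparing whole components rather than raw string prefixes avoids treating a version such as `1.260.0` as a match for `1.26`.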

#### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [12]:
project_id = 'google.com:hadoop-cloud-dev'  # @param {type:"string"}
region = "us-central1"  # @param {type: "string"}

# Set the project id
# ! gcloud config set project {project_id}

### Authenticate your Google Cloud account

The Cloud SDK, code and other libraries currently run as the service account identity of the Workbench Instance running this notebook.

**- Authenticate the Cloud SDK with your credentials:**

In [2]:
# ! gcloud auth login

**- Authenticate code and libraries with your credentials:**

In [None]:
# ! gcloud auth application-default login

**- Service account or other**
* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples.

### Import libraries

In [1]:
from pyspark.sql import SparkSession
from pyspark import SparkConf
import numpy as np
from sentence_transformers import SentenceTransformer
from google.cloud import bigquery

from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import struct, col, array, udf, lit
from pyspark.sql.types import ArrayType, FloatType


### Create Spark Session & load the data

In [2]:
spark = SparkSession.builder.appName("Embeddings")\
.getOrCreate()
sc = spark.sparkContext

24/12/18 14:41:39 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [3]:
bq_dataset = 'jvidhi_test'
stackoverflow_table_name = f'{bq_dataset}.stackoverflow_questions'
stackoverflow_index = f'{bq_dataset}.stackoverflow_index'

stackoverflow_data = spark.read.format('bigquery') \
  .option('table', 'bigquery-public-data.stackoverflow.posts_questions') \
  .load()

stackoverflow_data = stackoverflow_data.select('title')

##### Understand the data

In [5]:
print(stackoverflow_data.columns)
print(stackoverflow_data.count())

['title']
23020127


In [6]:
stackoverflow_data.show(5)

                                                                                

+--------------------+
|               title|
+--------------------+
|Html.ActionLink d...|
| Primitive recursion|
|  While vs. Do While|
|Protect ASP.NET S...|
|Difference betwee...|
+--------------------+
only showing top 5 rows



### Create batch prediction function

The factory function below loads the model on the executors and returns a `predict` function. Spark then calls `predict` on batches of rows, which runs distributed inference over the Spark DataFrame.
Learn more: https://spark.apache.org/docs/3.4.3/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html

In [7]:
def predict_batch_fn():
    # Imports live inside the factory so they are resolved on the executors.
    import numpy as np
    import torch
    from sentence_transformers import SentenceTransformer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    # The model is loaded once here and reused for every batch.
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
    model.to(device)

    def predict(inputs: np.ndarray) -> np.ndarray:
        # inputs: array of batch_size title strings
        embeddings = model.encode(inputs.tolist())
        return embeddings  # shape (batch_size, 384)

    return predict

In [8]:
results = predict_batch_udf(predict_batch_fn,
                          return_type=ArrayType(FloatType()),
                          batch_size=50)
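The factory-plus-batching pattern used above can be illustrated without Spark or a real model. The sketch below is a simplified stand-in, not how `predict_batch_udf` is implemented: `make_predict_fn`, `batches`, and the toy two-number "embedding" are all hypothetical names and values chosen for illustration. It shows only the idea that the expensive setup runs once while the returned function is called per batch.

```python
from typing import Callable, Iterator, List

load_count = 0  # counts how often the expensive "model load" runs


def make_predict_fn() -> Callable[[List[str]], List[List[float]]]:
    """Factory: do the expensive setup once, return a cheap per-batch function."""
    global load_count
    load_count += 1  # stand-in for loading SentenceTransformer weights

    def predict(batch: List[str]) -> List[List[float]]:
        # Toy "embedding": [character count, word count] for each title.
        return [[float(len(text)), float(len(text.split()))] for text in batch]

    return predict


def batches(rows: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `batch_size` rows."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


titles = ["Primitive recursion", "While vs. Do While", "Spark on Dataproc"]
predict = make_predict_fn()  # setup happens once
embeddings = [vec for batch in batches(titles, batch_size=2) for vec in predict(batch)]

print(load_count)       # 1
print(len(embeddings))  # 3
print(embeddings[0])    # [19.0, 2.0]
```

In the notebook, `batch_size=50` plays the role of `batch_size=2` here: larger batches amortize per-call overhead but need more memory per call.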

In [9]:
%%time
prediction = stackoverflow_data.withColumn("embeddings", results('title'))

CPU times: user 7.34 ms, sys: 3.86 ms, total: 11.2 ms
Wall time: 49.2 ms


In [10]:
prediction.show(5)

+--------------------+--------------------+
|               title|          embeddings|
+--------------------+--------------------+
|Html.ActionLink d...|[-0.003286022, -0...|
| Primitive recursion|[-0.099161394, 0....|
|  While vs. Do While|[0.026201472, -0....|
|Protect ASP.NET S...|[-0.045438357, 0....|
|Difference betwee...|[0.08556235, -0.0...|
+--------------------+--------------------+
only showing top 5 rows



                                                                                

#### Save the dataframe as a table in BigQuery. We will create a vector index on this table

In [None]:
# Writing the whole table to BigQuery fails, so only the first 200,000 rows are written.

prediction.limit(200000).write.mode('overwrite')\
.format("bigquery")\
.option("table",f"{project_id}.{stackoverflow_table_name}_embeddings")\
.option("writeMethod","direct")\
.save()

## Create a vector index in BigQuery
Learn more: https://cloud.google.com/bigquery/docs/vector-index

In [None]:
client = bigquery.Client()

In [None]:
# Query parameters can only bind values, not identifiers such as index or
# table names, so the names are interpolated into the DDL string instead.
embeddings_table = f"{project_id}.{stackoverflow_table_name}_embeddings"  # table written above
index_name = stackoverflow_index.split(".")[-1]  # the index is created in the table's dataset

query = f"""
    CREATE VECTOR INDEX {index_name} ON `{embeddings_table}`(embeddings)
    OPTIONS (index_type = 'TREE_AH', distance_type = 'EUCLIDEAN',
    tree_ah_options = '{{"normalization_type": "L2"}}');
"""
query_job = client.query(query)  # Make an API request.
query_job.result()  # Wait for the DDL statement to finish.
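BigQuery query parameters bind values rather than identifiers, so a statement like the one above typically needs its index and table names placed directly in the SQL text. One way to do that with some safety is to validate the names first. The helper below is a minimal sketch: `build_vector_index_ddl` and its conservative identifier rule are assumptions introduced for illustration, not part of the BigQuery client library.

```python
import json
import re

# Conservative rule for illustration: letters, digits, and underscores only.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def build_vector_index_ddl(index_name: str, dataset: str, table: str, column: str) -> str:
    """Build a CREATE VECTOR INDEX statement from validated identifiers."""
    for name in (index_name, dataset, table, column):
        if not _IDENTIFIER.fullmatch(name):
            raise ValueError(f"unsafe identifier: {name!r}")
    tree_ah_options = json.dumps({"normalization_type": "L2"})
    return (
        f"CREATE VECTOR INDEX {index_name} ON `{dataset}.{table}`({column})\n"
        f"OPTIONS (index_type = 'TREE_AH', distance_type = 'EUCLIDEAN',\n"
        f"tree_ah_options = '{tree_ah_options}');"
    )


print(build_vector_index_ddl(
    "stackoverflow_index", "jvidhi_test",
    "stackoverflow_questions_embeddings", "embeddings",
))
```

The resulting string can be passed to `client.query(...)` without a `job_config`. Real BigQuery names may allow more characters than this check accepts, so the rule would need loosening for such cases.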

#### Generate an embedding for the search query, which is then matched against the vector index to find similar items

In [None]:
query_sentence = "Apache Spark on Dataproc"
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
embeddings = model.encode(query_sentence).tolist()
# embeddings

In [None]:
# The table name is an identifier, so it is interpolated into the SQL text;
# only the query embedding is passed as a query parameter.
embeddings_table = f"{project_id}.{stackoverflow_table_name}_embeddings"

query = f"""
    SELECT * FROM
      VECTOR_SEARCH(TABLE `{embeddings_table}`, 'embeddings',
        (SELECT @embeddings AS embeddings),
        top_k => 5, options => '{{"fraction_lists_to_search": 0.01}}');
"""
job_config = bigquery.QueryJobConfig(
    query_parameters=[
        bigquery.ArrayQueryParameter("embeddings", "FLOAT64", embeddings),
    ]
)
query_job = client.query(query, job_config=job_config)  # Make an API request.

In [None]:
for row in query_job:
    # Each result row holds the query struct, the matched base row, and the distance.
    print(row["base"]["title"])
    print(f'Distance: {row["distance"]}\n')
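To make the search step concrete, the sketch below performs an exact, brute-force nearest-neighbour lookup by Euclidean distance in plain Python. It is not what BigQuery runs internally (an indexed `VECTOR_SEARCH` is approximate), and the two-dimensional vectors and titles are made-up toy data standing in for the real 384-dimensional embeddings; `euclidean_distance` and `top_k` are hypothetical helper names.

```python
import math
from typing import List, Tuple


def euclidean_distance(a: List[float], b: List[float]) -> float:
    """Straight-line distance between two vectors of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def top_k(query: List[float], rows: List[Tuple[str, List[float]]], k: int) -> List[Tuple[str, float]]:
    """Return the k (title, distance) pairs closest to the query vector."""
    scored = [(title, euclidean_distance(query, vec)) for title, vec in rows]
    return sorted(scored, key=lambda item: item[1])[:k]


# Toy data: (title, embedding) pairs with 2-dimensional vectors.
rows = [
    ("Spark on Dataproc", [1.0, 0.0]),
    ("While vs. Do While", [0.0, 1.0]),
    ("PySpark batch jobs", [0.8, 0.6]),
]

for title, distance in top_k([1.0, 0.0], rows, k=2):
    print(f"{title}: {distance:.3f}")
```

A brute-force scan like this compares the query against every row, which is why an index becomes useful once the table holds many embeddings.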

## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud
project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial:

In [None]:
# Clean up the Spark session
spark.stop()
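The BigQuery resources created in this tutorial (the embeddings table and its vector index) can also be removed individually. The function below is a sketch under assumptions: `cleanup_bigquery` is a hypothetical helper, and it expects a `google.cloud.bigquery.Client`-style object passed in by the caller. It relies on the `DROP VECTOR INDEX` DDL statement and the client's `delete_table` method.

```python
def cleanup_bigquery(client, table_id: str, index_name: str) -> None:
    """Drop the vector index, then delete the embeddings table.

    `client` is expected to behave like google.cloud.bigquery.Client.
    """
    # Identifiers cannot be query parameters, so they are placed in the SQL text.
    drop_index = f"DROP VECTOR INDEX IF EXISTS {index_name} ON `{table_id}`"
    client.query(drop_index).result()  # wait for the DDL statement to finish
    # not_found_ok=True makes the call a no-op if the table is already gone.
    client.delete_table(table_id, not_found_ok=True)


# Hypothetical usage with the names from this notebook:
#   cleanup_bigquery(
#       bigquery.Client(),
#       f"{project_id}.{stackoverflow_table_name}_embeddings",
#       "stackoverflow_index",
#   )
```

Deleting the table is expected to remove its index as well, so the explicit drop is mainly useful if you want to keep the table and rebuild the index.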