In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Model Garden RAG API

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_rag.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_rag.ipynb">
      <img width="32px" src="https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_rag.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_rag.ipynb">
      <img src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

## 0. Set up the Environment and Test Project

In [None]:
# Install the Vertex AI SDK (with numpy pinned below 2.0.0) and the OpenAI client.
# --force-reinstall reinstalls the packages even if they are already present.
!pip3 install --force-reinstall google-cloud-aiplatform "numpy<2.0.0" --user
# The OpenAI client is used in the ChatCompletions section of this notebook.
!pip install --upgrade --quiet openai

In [None]:
import sys

# Authenticate the notebook user so the Vertex AI SDK can pick up credentials.
# `google.colab` only exists inside Colab, so this step is skipped elsewhere
# (e.g. Workbench), where credentials are presumably already configured.
# The gcloud CLI cannot be installed with pip, so no extra install is done here.
if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

**Remember to restart the runtime after the pip install.** On Colab, the next cell restarts it automatically.

In [None]:
import sys

_ON_COLAB = "google.colab" in sys.modules

# On Colab, shut the kernel down with restart=True so the packages installed
# above are imported fresh; on other platforms, restart the kernel manually.
if _ON_COLAB:
    from IPython import Application

    app = Application.instance()
    app.kernel.do_shutdown(True)

## Initialization


In [None]:
import vertexai
from vertexai.preview import rag
from vertexai.preview.generative_models import GenerativeModel, Tool

In [None]:
# Set Project
# Must be non-empty: PROJECT_ID is interpolated into resource names later in
# this notebook (e.g. `projects/{PROJECT_ID}/locations/us-central1/...`).
PROJECT_ID = ""  # @param {type:"string", "placeholder": "your-project-id"}

In [None]:
# Initialize the Vertex AI SDK for this project. The region is fixed to
# us-central1, which the endpoint and resource names later in this notebook
# also hard-code.
vertexai.init(
    location="us-central1",
    project=PROJECT_ID,
)

## Create a RAG corpus


In [None]:
# Name your corpus
DISPLAY_NAME = ""  # @param {type:"string", "placeholder": "your-corpus-name"}

# Default setup: embed with the Google first-party `text-embedding-004`
# publisher model. No `vector_db` argument is passed to `create_corpus` here.
embedding_model_config = rag.EmbeddingModelConfig(
    publisher_model="publishers/google/models/text-embedding-004",
)
rag_corpus = rag.create_corpus(
    display_name=DISPLAY_NAME,
    embedding_model_config=embedding_model_config,
)

# ---------------------------------------------------------------------------
# Alternatives (uncomment as needed):
# - The embedding-endpoint option below only rebuilds `embedding_model_config`,
#   so it has to run before the `rag.create_corpus` call above to take effect.
# - Each vector-database option below calls `rag.create_corpus` itself; comment
#   out the default call above first, or two corpora will be created.
# ---------------------------------------------------------------------------

# Use other embedding models
# Configure a third-party model or a Google fine-tuned first-party model as a Vertex Endpoint resource
# See https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_e5.ipynb
# for deploying 3P embedding models to endpoints
# EMBEDDING_MODEL_ENDPOINT_ID = "" # @param {type:"string", "placeholder": "your-model-endpoint-id"}
# EMBEDDING_MODEL_ENDPOINT = f"projects/{PROJECT_ID}/locations/us-central1/endpoints/{EMBEDDING_MODEL_ENDPOINT_ID}"
# embedding_model_config = rag.EmbeddingModelConfig(
#     endpoint=EMBEDDING_MODEL_ENDPOINT,
# )

# Use Pinecone as the Vector Database
# Configure a Pinecone Instance
# PINECONE_INDEX_NAME = "" # @param {type:"string", "placeholder": "your-pinecone-index-name"}
# PINECONE_API_KEY = "" # @param {type:"string", "placeholder": "your-secret-manager-resource-name"}
# vector_db = rag.Pinecone(
#     index_name=PINECONE_INDEX_NAME,
#     api_key=PINECONE_API_KEY,
# )
# rag_corpus = rag.create_corpus(
#     display_name=DISPLAY_NAME, embedding_model_config=embedding_model_config, vector_db=vector_db
# )

# Use Weaviate as the Vector Database
# Configure a Weaviate Vector Database Instance for the corpus
# WEAVIATE_HTTP_ENDPOINT = "" # @param {type:"string", "placeholder": "your-weaviate-http-endpoint"}
# COLLECTION_NAME = "" # @param {type:"string", "placeholder": "your-weaviate-collection-name"}
# API_KEY = "" # @param {type:"string", "placeholder": "your-secret-manager-resource-name"}
# vector_db = rag.Weaviate(
#     weaviate_http_endpoint=WEAVIATE_HTTP_ENDPOINT,
#     collection_name=COLLECTION_NAME,
#     api_key=API_KEY,
# )
# rag_corpus = rag.create_corpus(
#     display_name=DISPLAY_NAME, embedding_model_config=embedding_model_config, vector_db=vector_db
# )

# Use Vertex Feature Store as the Vector Database
# Configure a Vertex Feature Store Instance
# FEATURE_VIEW_RESOURCE_NAME = "" # @param {type:"string", "placeholder": "your-feature-view-resource-name"}
# vector_db = rag.VertexFeatureStore(
#     resource_name=FEATURE_VIEW_RESOURCE_NAME,
# )
# rag_corpus = rag.create_corpus(
#     display_name=DISPLAY_NAME, embedding_model_config=embedding_model_config, vector_db=vector_db
# )

# Use Vertex Vector Search as the Vector Database
# Configure a Vertex Vector Search instance
# VECTOR_SEARCH_INDEX_ENDPOINT = "" # @param {type:"string", "placeholder": "your-vector-search-index-endpoint"}
# VECTOR_SEARCH_INDEX = "" # @param {type:"string", "placeholder": "your-vector-search-index"}
# vector_db = rag.VertexVectorSearch(
#     index_endpoint=VECTOR_SEARCH_INDEX_ENDPOINT,
#     index=VECTOR_SEARCH_INDEX,
# )
# rag_corpus = rag.create_corpus(
#     display_name=DISPLAY_NAME, embedding_model_config=embedding_model_config, vector_db=vector_db
# )

In [None]:
# Check the corpus just created. `list_corpora()` is called without a filter,
# so it presumably lists every corpus in the project and region passed to
# `vertexai.init`, not only the one created above.
rag.list_corpora()

## Upload a file to the corpus

In [None]:
%%writefile test.txt

Here's a demo for Llama3 RAG

In [None]:
# Upload the local demo file written above into the new corpus.
rag_file = rag.upload_file(
    path="test.txt",
    corpus_name=rag_corpus.name,
    description="my test",
    display_name="test.txt",
)

## Import files from Google Cloud Storage
Set `GS_BUCKET` to a Cloud Storage URI such as `gs://your-bucket/your-folder`. Remember to grant "Viewer" access to the "Vertex RAG Data Service Agent" (with the format of `service-{project_number}@gcp-sa-vertex-rag.iam.gserviceaccount.com`) on your Google Cloud Storage bucket.

In [None]:
GS_BUCKET = ""  # @param {type:"string", "placeholder": "your-gs-bucket"}

response = await rag.import_files_async(  # noqa: F704
    corpus_name=rag_corpus.name,
    paths=[GS_BUCKET],
    chunk_size=512,
    chunk_overlap=50,
)

In [None]:
# List the corpus files after the Cloud Storage import; processing can take a
# few seconds, so re-run this cell if the new files are not shown yet.
[*rag.list_files(corpus_name=rag_corpus.name)]

## Import files from Google Drive
Eligible paths are `https://drive.google.com/drive/folders/{folder_id}` (a whole folder) or `https://drive.google.com/file/d/{file_id}` (a single file).

Remember to grant "Viewer" access to the "Vertex RAG Data Service Agent" (with the format of `service-{project_number}@gcp-sa-vertex-rag.iam.gserviceaccount.com`) for your Drive folder/files.

In [None]:
FILE_ID = ""  # @param {type:"string", "placeholder": "your-file-id"}
# Single-file Drive URL; to import a whole folder, use
# https://drive.google.com/drive/folders/{folder_id} instead.
FILE_PATH = "https://drive.google.com/file/d/" + FILE_ID

In [None]:
# Import the Drive file (or folder) into the corpus using the blocking
# `import_files` variant; its return value is displayed as the cell output.
rag.import_files(
    paths=[FILE_PATH],
    corpus_name=rag_corpus.name,
    chunk_overlap=100,
    chunk_size=1024,
)

In [None]:
# List the corpus files after the Drive import (allow a few seconds for the
# imported files to be processed).
[corpus_file for corpus_file in rag.list_files(corpus_name=rag_corpus.name)]

## Import files from Slack

In [None]:
# Slack import settings: the secret is given as a Secret Manager resource name.
# fmt: off
API_KEY_SECRET_VERSION = ""  # @param {type:"string", "placeholder": "your-secret-manager-resource-name"}
# fmt: on
CHANNEL_ID = ""  # @param {type:"string", "placeholder": "your-slack-channel-id"}

In [None]:
# Describe the single Slack channel to import, authenticated with the secret
# version configured above.
slack_channel = rag.SlackChannel(CHANNEL_ID, API_KEY_SECRET_VERSION)
slack_source = rag.SlackChannelsSource(channels=[slack_channel])

In [None]:
response = await rag.import_files_async(  # noqa: F704
    corpus_name=rag_corpus.name,
    source=slack_source,
    chunk_size=1024,
    chunk_overlap=200,
)

In [None]:
# List the corpus files after the Slack import; processing may take a few
# seconds before new entries appear.
imported_files = list(rag.list_files(corpus_name=rag_corpus.name))
imported_files

## Import files from Jira

In [None]:
# Jira import settings. Note: API_KEY_SECRET_VERSION is reassigned here, so it
# replaces the value entered in the Slack form above.
SERVER_URI = ""  # @param {type:"string", "placeholder": "your-server.atlassian.net"}
EMAIL = ""  # @param {type:"string", "placeholder": "your-email"}
PROJECT = ""  # @param {type:"string", "placeholder": "your-project-name"}
CUSTOM_QUERY = ""  # @param {type:"string", "placeholder": "your-custom-jql-query"}
# fmt: off
API_KEY_SECRET_VERSION = ""  # @param {type:"string", "placeholder": "your-secret-manager-resource-name"}
# fmt: on

In [None]:
# Build the Jira query from the form fields. A blank project or custom JQL
# field is dropped instead of being sent as an empty string (`[""]`), so a user
# can fill in just one of them (presumably either is sufficient; confirm
# against the RAG API reference).
jira_query = rag.JiraQuery(
    email=EMAIL,
    jira_projects=[PROJECT] if PROJECT else [],
    custom_queries=[CUSTOM_QUERY] if CUSTOM_QUERY else [],
    api_key=API_KEY_SECRET_VERSION,
    server_uri=SERVER_URI,
)

jira_source = rag.JiraSource(
    queries=[jira_query],
)

In [None]:
response = await rag.import_files_async(  # noqa: F704
    corpus_name=rag_corpus.name,
    source=jira_source,
    chunk_size=1024,
    chunk_overlap=200,
)

In [None]:
# List the corpus files after the Jira import (re-run if the imported files
# are still being processed).
corpus_files = rag.list_files(corpus_name=rag_corpus.name)
list(corpus_files)

## Using GenerateContent API with Google-operated Llama3 model endpoint

Retrieved contexts whose vector distance to the query is below `vector_distance_threshold` are cited (from the RAG store) in the generated content.

In [None]:
# Retrieval tool over the corpus: return up to 10 contexts, keeping only those
# within vector distance 0.4 of the query.
rag_resource = rag.RagResource(rag_corpus=rag_corpus.name)
rag_store = rag.VertexRagStore(
    rag_resources=[rag_resource],  # Currently only 1 corpus is allowed.
    similarity_top_k=10,
    vector_distance_threshold=0.4,
)
rag_retrieval_tool = Tool.from_retrieval(retrieval=rag.Retrieval(source=rag_store))

In [None]:
# Google-operated (MaaS) Llama 3.1 405B Instruct, addressed as a publisher model
# and wired to the RAG retrieval tool.
ENDPOINT = (
    f"projects/{PROJECT_ID}/locations/us-central1"
    "/publishers/meta/models/llama-3.1-405b-instruct-maas"
)
rag_model = GenerativeModel(ENDPOINT, tools=[rag_retrieval_tool])

In [None]:
GENERATE_CONTENT_PROMPT = "What is RAG and why it is helpful?"  # @param {type:"string"}

# Send the prompt to the Google-operated model; the attached tool retrieves
# context from the corpus first.
response = rag_model.generate_content(contents=GENERATE_CONTENT_PROMPT)

In [None]:
# Display the full GenerateContent response returned by the cell above.
response

## Using GenerateContent API with self-deployed Llama3 model endpoint

Retrieved contexts whose vector distance to the query is below `vector_distance_threshold` are cited (from the RAG store) in the generated content.


In [None]:
# Rebuild the same retrieval tool so this section can run on its own:
# up to 10 contexts, distance threshold 0.4, single corpus.
retrieval_settings = dict(similarity_top_k=10, vector_distance_threshold=0.4)
rag_resource = rag.RagResource(rag_corpus=rag_corpus.name)
retrieval = rag.Retrieval(
    source=rag.VertexRagStore(
        rag_resources=[rag_resource],  # Currently only 1 corpus is allowed.
        **retrieval_settings,
    ),
)
rag_retrieval_tool = Tool.from_retrieval(retrieval=retrieval)

In [None]:
ENDPOINT_ID = ""  # @param {type:"string", "placeholder": "your-endpoint-id"}

# Self-deployed model: address the Vertex AI endpoint resource directly.
ENDPOINT = "projects/{}/locations/us-central1/endpoints/{}".format(
    PROJECT_ID, ENDPOINT_ID
)
rag_model = GenerativeModel(ENDPOINT, tools=[rag_retrieval_tool])

In [None]:
GENERATE_CONTENT_PROMPT = "What is RAG and why it is helpful?"  # @param {type:"string"}

# Query the self-deployed endpoint with RAG grounding.
prompt = GENERATE_CONTENT_PROMPT
response = rag_model.generate_content(prompt)

In [None]:
# Display the full GenerateContent response from the self-deployed endpoint.
response

## Using ChatCompletions API with Google-operated Llama3 model endpoint

Use the OpenAI-compatible ChatCompletions API and pass the RAG retrieval configuration in `extra_body`.

In [None]:
import google.auth.transport.requests  # noqa: F401  # makes `transport.requests` available
import openai
from google.auth import default, transport

# Get an OAuth access token from Application Default Credentials. The
# cloud-platform scope is requested explicitly so that credentials which need
# explicit scopes (e.g. service accounts) can also be refreshed.
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = transport.requests.Request()
credentials.refresh(auth_request)

# Vertex AI's OpenAI-compatible endpoint. The OpenAI client appends
# `/chat/completions` to `base_url` itself, so the base URL stops at
# `/endpoints/openapi`.
# NOTE: the access token is short-lived (typically about an hour); re-run this
# cell if requests start failing with authentication errors.
client = openai.OpenAI(
    base_url=f"https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/endpoints/openapi",
    api_key=credentials.token,
)

In [None]:
CHAT_COMPLETIONS_PROMPT = "What is RAG and why it is helpful?"  # @param {type:"string"}

# The outer `extra_body` is the OpenAI client option that merges extra JSON into
# the request body; the inner `"extra_body"` -> `"google"` keys carry the
# Vertex-specific RAG store configuration.
response = client.chat.completions.create(
    model="meta/llama-3.1-405b-instruct-maas",
    messages=[{"role": "user", "content": CHAT_COMPLETIONS_PROMPT}],
    extra_body={
        "extra_body": {
            "google": {
                "vertex_rag_store": {
                    "rag_resources": {"rag_corpus": rag_corpus.name},
                    "similarity_top_k": 10,
                    # Same threshold as the GenerateContent and retrieval_query
                    # sections, so every API filters contexts identically.
                    "vector_distance_threshold": 0.4,
                }
            }
        }
    },
)

In [None]:
# Display the ChatCompletion object returned by the OpenAI client.
response

## Using other generation API with Llama3 model endpoint

The retrieved contexts can be passed to any SDK or model generation API to generate final results.


In [None]:
RETRIEVAL_QUERY = "What is RAG and why it is helpful?"  # @param {type:"string"}

rag_resource = rag.RagResource(
    rag_corpus=rag_corpus.name,
)

# Retrieve matching contexts directly, without generating an answer.
response = rag.retrieval_query(
    rag_resources=[rag_resource],  # Currently only 1 corpus is allowed.
    text=RETRIEVAL_QUERY,
    similarity_top_k=10,
    vector_distance_threshold=0.4,
)

# Flatten the retrieved contexts into one string that can be passed to any SDK
# or model generation API. Line breaks become spaces (not "") so the words on
# either side of a newline are not glued together.
retrieved_context = " ".join(
    context.text for context in response.contexts.contexts
).replace("\n", " ")

In [None]:
# Display the flattened retrieval context built in the cell above.
retrieved_context

## Cleaning up

Clean up resources created in this notebook.

In [None]:
delete_rag_corpus = False  # @param {type:"boolean"}
delete_bucket = False  # @param {type:"boolean"}

if delete_rag_corpus:
    rag_corpus_list = rag.list_corpora()
    for rag_corpus in rag_corpus_list:
        rag.delete_corpus(name=rag_corpus.name)

if delete_bucket:
    ! gsutil -m rm -r $GS_BUCKET

## API reference

For more details on RAG corpus/file management and detailed support, visit https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api
