In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Model Garden RAG API

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_rag.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_rag.ipynb">
      <img width="32px" src="https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_rag.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_rag.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

## 0. Set up the Environment and Test Project

In [1]:
!pip3 install --force-reinstall google-cloud-aiplatform "numpy<2.0.0" --user
!pip install --upgrade --quiet openai

Collecting google-cloud-aiplatform
  Downloading google_cloud_aiplatform-1.75.0-py2.py3-none-any.whl.metadata (31 kB)
Collecting numpy<2.0.0
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,<3.0.0dev,>=1.34.1 (from google-api-core[grpc]!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,<3.0.0dev,>=1.34.1->google-cloud-aiplatform)
  Downloading google_api_core-2.24.0-py3-none-any.whl.metadata (3.0 kB)
Collecting google-auth<3.0.0dev,>=2.14.1 (from google-cloud-aiplatform)
  Downloading google_auth-2.37.0-py2.py3-none-any.whl.metadata (4.8 kB)
Collecting proto-plus<2.0.0dev,>=1.22.3 (from google-cloud-aiplatform)
  Downloading proto_plus-1.25.0-py3-none-any.whl.metadata (2.2 kB)
Collecting protobuf!=4.21.0,!=4

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/454.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m450.6/454.3 kB[0m [31m13.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m454.3/454.3 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: Operation cancelled by user[0m[31m
[0m

In [1]:
from google.colab import auth

auth.authenticate_user()

# Install gcloud
!pip install google-cloud

Collecting google-cloud
  Downloading google_cloud-0.34.0-py2.py3-none-any.whl.metadata (2.7 kB)
Downloading google_cloud-0.34.0-py2.py3-none-any.whl (1.8 kB)
Installing collected packages: google-cloud
Successfully installed google-cloud-0.34.0


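**Restart the runtime after the pip installs above so that the newly installed packages are picked up. In Colab, the next cell does this for you.**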

In [2]:
import sys

# Restart the runtime when running in Colab so the packages installed above are loaded.
# Outside Colab this cell does nothing.
if "google.colab" in sys.modules:
    import IPython

    # `True` is the restart flag passed to the kernel's shutdown routine.
    IPython.Application.instance().kernel.do_shutdown(True)

## Initialization


In [1]:
import vertexai
from vertexai.preview import rag
from vertexai.preview.generative_models import GenerativeModel, Tool

In [2]:
# Set Project
# Google Cloud project ID used by the Vertex AI and chat-completions calls below.
# Replace the value with your own project ID before running.
PROJECT_ID = "upsamplex5"  # @param {type:"string", "placeholder": "your-project-id"}

In [3]:
# Initialize the Vertex AI SDK for PROJECT_ID in the us-central1 location.
vertexai.init(project=PROJECT_ID, location="us-central1")

## Create a RAG corpus


In [9]:
# List the RAG corpora available to the initialized project/location.
rag.list_corpora()

ListRagCorporaPager<rag_corpora {
  name: "projects/upsamplex5/locations/us-central1/ragCorpora/2305843009213693952"
  display_name: "1624495570.us-central1-1036911965724.vdb.vertexai.goog"
  create_time {
    seconds: 1734621510
    nanos: 579627000
  }
  update_time {
    seconds: 1734621510
    nanos: 579627000
  }
  rag_embedding_model_config {
    vertex_prediction_endpoint {
      endpoint: "projects/1036911965724/locations/us-central1/publishers/google/models/text-embedding-004"
    }
  }
  rag_vector_db_config {
    rag_managed_db {
    }
    rag_embedding_model_config {
      vertex_prediction_endpoint {
        endpoint: "projects/1036911965724/locations/us-central1/publishers/google/models/text-embedding-004"
      }
    }
  }
  corpus_status {
    state: ACTIVE
  }
  vector_db_config {
    rag_managed_db {
    }
    rag_embedding_model_config {
      vertex_prediction_endpoint {
        endpoint: "projects/1036911965724/locations/us-central1/publishers/google/models/text-em

In [12]:
# Load the existing RAG corpus that the rest of the notebook works against.
# The resource name is built from PROJECT_ID so that changing the project in the form
# field is enough; set RAG_CORPUS_ID to one of the IDs shown by `rag.list_corpora()`.
RAG_CORPUS_ID = "8791026472627208192"  # @param {type:"string", "placeholder": "your-rag-corpus-id"}
rag_corpus = rag.get_corpus(
    name=f"projects/{PROJECT_ID}/locations/us-central1/ragCorpora/{RAG_CORPUS_ID}"
)

In [7]:
# Vertex AI Vector Search resources wrapped for use with the RAG API.
# Each variable must hold the matching resource type: the endpoint is an
# `.../indexEndpoints/...` resource name and the index is an `.../indexes/...` resource name.
VECTOR_SEARCH_INDEX_ENDPOINT = "projects/1036911965724/locations/us-central1/indexEndpoints/4465068341886713856" # @param {type:"string", "placeholder": "your-vector-search-index-endpoint"}
VECTOR_SEARCH_INDEX = "projects/1036911965724/locations/us-central1/indexes/8536111298797109248" # @param {type:"string", "placeholder": "your-vector-search-index"}
# NOTE(review): `vector_db` is not referenced by any later cell in this notebook — confirm
# whether it is meant to be passed when creating a corpus.
vector_db = rag.VertexVectorSearch(
    index_endpoint=VECTOR_SEARCH_INDEX_ENDPOINT,
    index=VECTOR_SEARCH_INDEX,
)

## Upload a file to the corpus

In [None]:
%%writefile test.txt

Here's a demo for Llama3 RAG

In [None]:
# Upload the local demo file (written by the previous cell) into the selected corpus.
rag_file = rag.upload_file(
    path="test.txt",
    corpus_name=rag_corpus.name,
    description="my test",
    display_name="test.txt",
)

## Import files from Google Cloud Storage
Remember to grant "Viewer" access on your Google Cloud Storage bucket to the "Vertex RAG Data Service Agent" (a service account of the form `service-{project_number}@gcp-sa-vertex-rag.iam.gserviceaccount.com`).

In [None]:
GS_BUCKET = ""  # @param {type:"string", "placeholder": "your-gs-bucket"}

response = await rag.import_files_async(  # noqa: F704
    corpus_name=rag_corpus.name,
    paths=[GS_BUCKET],
    chunk_size=512,
    chunk_overlap=50,
)

In [13]:
# Show every file currently registered in the corpus.
# Newly imported files may need a few seconds of processing before they appear.
[*rag.list_files(corpus_name=rag_corpus.name)]

[name: "projects/1036911965724/locations/us-central1/ragCorpora/8791026472627208192/ragFiles/5327824635373927734"
 display_name: "Direction_13082024.json"
 create_time {
   seconds: 1734637817
   nanos: 838125000
 }
 update_time {
   seconds: 1734637817
   nanos: 838125000
 }
 gcs_source {
   uris: "gs://rag-consumer-1/Direction_13082024.json"
 }
 file_status {
   state: ACTIVE
 },
 name: "projects/1036911965724/locations/us-central1/ragCorpora/8791026472627208192/ragFiles/5327824635387522285"
 display_name: "Airtel FAQ\'s.json"
 create_time {
   seconds: 1734637817
   nanos: 844751000
 }
 update_time {
   seconds: 1734637817
   nanos: 844751000
 }
 gcs_source {
   uris: "gs://rag-consumer-1/Airtel FAQ\'s.json"
 }
 file_status {
   state: ACTIVE
 },
 name: "projects/1036911965724/locations/us-central1/ragCorpora/8791026472627208192/ragFiles/5327824662524596017"
 display_name: "TCCR_Regulation_3rd_amendmen.json"
 create_time {
   seconds: 1734637820
   nanos: 971375000
 }
 update_time {

## Using GenerateContent API with Google-operated Llama3 model endpoint

When the similarity distance between the query and a retrieved chunk is below `vector_distance_threshold`, the generated content cites the retrieved context (from the RAG store).

In [14]:
# Point retrieval at the corpus selected earlier.
rag_resource = rag.RagResource(rag_corpus=rag_corpus.name)

# Build the retrieval tool handed to the generative model: a Vertex RAG store over the
# resource above, returning up to 10 results within a 0.4 vector-distance threshold.
rag_retrieval_tool = Tool.from_retrieval(
    retrieval=rag.Retrieval(
        source=rag.VertexRagStore(
            vector_distance_threshold=0.4,
            similarity_top_k=10,
            # Currently only 1 corpus is allowed.
            rag_resources=[rag_resource],
        ),
    )
)

In [15]:
# Resource name of the Meta Llama 3.1 405B instruct model under this project in us-central1.
ENDPOINT = f"projects/{PROJECT_ID}/locations/us-central1/publishers/meta/models/llama-3.1-405b-instruct-maas"

# Generative model bound to that endpoint with the RAG retrieval tool attached.
rag_model = GenerativeModel(ENDPOINT, tools=[rag_retrieval_tool])

In [None]:
# Prompt sent through the GenerateContent API; editable via the form field.
GENERATE_CONTENT_PROMPT = "What is RAG and why it is helpful?"  # @param {type:"string"}

# Run the prompt against the RAG-enabled model created in the previous cell.
response = rag_model.generate_content(GENERATE_CONTENT_PROMPT)

In [None]:
# Display the GenerateContent response object.
response

## Using ChatCompletions API with Google-operated Llama3 model endpoint

Use the OpenAI-compatible ChatCompletions API and pass the RAG retrieval configuration in `extra_body`.

In [16]:
import openai
from google.auth import default, transport

# `requests` is a submodule of `google.auth.transport`; importing only the package does not
# load it. Import it explicitly so `transport.requests` below resolves regardless of which
# other libraries happen to have been imported first.
import google.auth.transport.requests  # noqa: F401

# Obtain the default Google credentials and refresh them to get a current access token.
credentials, _ = default()
auth_request = transport.requests.Request()
credentials.refresh(auth_request)

# OpenAI-compatible client for the Vertex AI endpoint; the access token is supplied as the API key.
# NOTE(review): the token is read once here — presumably it expires and this cell must be
# re-run to obtain a fresh one; confirm.
# NOTE(review): the base URL deliberately keeps the trailing "chat/completions?" unchanged;
# verify against the Vertex AI OpenAI-compatibility documentation before altering it.
client = openai.OpenAI(
    base_url=f"https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/endpoints/openapi/chat/completions?",
    api_key=credentials.token,
)

In [None]:
# Question sent through the OpenAI-compatible ChatCompletions API; editable via the form field.
CHAT_COMPLETIONS_PROMPT = "What is RAG and why it is helpful?"  # @param {type:"string"}

# Single-turn request. The Google-specific payload nested under `extra_body` points the
# request at the RAG corpus with similarity_top_k=10.
response = client.chat.completions.create(
    messages=[{"role": "user", "content": CHAT_COMPLETIONS_PROMPT}],
    model="meta/llama-3.1-405b-instruct-maas",
    extra_body={
        "extra_body": {
            "google": {
                "vertex_rag_store": {
                    "rag_resources": {"rag_corpus": rag_corpus.name},
                    "similarity_top_k": 10,
                }
            }
        }
    },
)

In [None]:
# Display the ChatCompletions response object.
response

# Conversation

In [18]:
# Recorded chat sessions; each entry is the full message list of one session.
convo = []


def _first_grounding_uri(choice):
    """Return the URI of the first retrieved grounding chunk on ``choice``, or ``None``.

    The lookup walks ``extra_properties['google']['grounding_metadata']
    ['grounding_chunks'][0]['retrieved_context']['uri']``. Any missing attribute, key or
    empty chunk list yields ``None`` instead of raising.
    """
    extra = getattr(choice, "extra_properties", None) or {}
    try:
        return extra["google"]["grounding_metadata"]["grounding_chunks"][0]["retrieved_context"]["uri"]
    except (KeyError, IndexError, TypeError):
        return None


def manage_conversation():
    """Run an interactive, RAG-grounded customer-service chat on stdin/stdout.

    Reads user turns with ``input()`` until the user types ``exit`` (case-insensitive),
    sends the accumulated history to the chat-completions API on every turn, prints the
    reply, and appends it to the history. When the reply carries a grounding chunk, the
    source file URI is appended to the reply text as ``Info from file <uri>``.

    The session's message list is registered in the module-level ``convo`` list once, after
    the first completed turn, so ``convo[-1]`` is the latest session. A session in which
    the user exits immediately records nothing.

    Relies on the module-level ``client`` and ``rag_corpus`` objects.
    Returns ``None``.
    """
    # Initialize the conversation history
    conversation_history = [
        {"role": "system", "content": "You are a customer service assistant bot. Your job is to assist customers when they have a query or issue with anything pertaining to your context."}
    ]

    print("Bot: Hello! How can I assist you today? (Type 'exit' to end the conversation.)")

    while True:
        # Take user input
        user_input = input("You: ").strip()

        # Check if the user wants to exit
        if user_input.lower() == "exit":
            print("Bot: Thank you for chatting with me. Have a great day!")
            break

        # Append user input to conversation history
        conversation_history.append({"role": "user", "content": user_input})

        # Call the Llama chat completions API
        response = client.chat.completions.create(
            model="meta/llama-3.1-70b-instruct-maas",
            messages=conversation_history,
            extra_body={
                "extra_body": {
                    "google": {
                        "vertex_rag_store": {
                            "rag_resources": {"rag_corpus": rag_corpus.name},
                            "similarity_top_k": 5,
                        }
                    }
                }
            },
        )

        # Extract the assistant's response; cite the source file only when one was retrieved,
        # so an ungrounded answer no longer aborts the chat with KeyError/IndexError.
        choice = response.choices[0]
        bot_response = choice.message.content or ""
        source_uri = _first_grounding_uri(choice)
        if source_uri is not None:
            bot_response += f"\nInfo from file {source_uri}"

        # Print and append bot response to conversation history
        print(f"Bot: {bot_response}")
        conversation_history.append({"role": "assistant", "content": bot_response})

        # Register this session once; the list is mutated in place on later turns, so
        # appending it again every turn would only add duplicate references.
        if not convo or convo[-1] is not conversation_history:
            convo.append(conversation_history)

# Start the chat
# Blocks on input() until the user types 'exit'.
manage_conversation()

Bot: Hello! How can I assist you today? (Type 'exit' to end the conversation.)
You: tell me about regulation tariffs
Bot: There is no single tariff plan which is uniformly best suited for each and every subscriber. There are a large number of tariff schemes in the market targeted at different user categories. It is essential for a subscriber to estimate his expected volume of usage and the pattern of usage and other preferences before deciding on the plan he should subscribe to.
Info from file gs://rag-consumer-1/FAQ.json
You: tell me some details about some regulation tariffs
Bot: There is no single tariff plan which is uniformly best suited for each and every subscriber. There are a large number of tariff schemes in the market targeted at different user categories. It is essential for a subscriber to estimate his expected volume of usage and the pattern of usage and other preferences before deciding on the plan he should subscribe to.

No charges are payable by the subscriber for mig

In [19]:
# Show the message list of the most recently recorded conversation.
convo[-1]

[{'role': 'system',
  'content': 'You are a customer service assistant bot. Your job is to assist customers when they have a query or issue with anything pertaining to your context.'},
 {'role': 'user', 'content': 'tell me about regulation tariffs'},
 {'role': 'assistant',
  'content': 'There is no single tariff plan which is uniformly best suited for each and every subscriber. There are a large number of tariff schemes in the market targeted at different user categories. It is essential for a subscriber to estimate his expected volume of usage and the pattern of usage and other preferences before deciding on the plan he should subscribe to.\nInfo from file gs://rag-consumer-1/FAQ.json'},
 {'role': 'user',
  'content': 'tell me some details about some regulation tariffs'},
 {'role': 'assistant',
  'content': 'There is no single tariff plan which is uniformly best suited for each and every subscriber. There are a large number of tariff schemes in the market targeted at different user ca

# Addition to dataset

In [21]:
import json

# Instruction asking the model for a recap of the recorded chat.
CHAT_COMPLETIONS_PROMPT = "Please summarize the conversation between the user and the assistant. Provide a brief description of the user's queries and the assistant's responses."

# Flatten the most recent conversation into one transcript string.
# Every message is labelled with its own role (System / User / Assistant), so the system
# prompt is no longer attributed to the assistant.
conversation_history = "".join(
    f"{message['role'].capitalize()}: {message['content']} " for message in convo[-1]
)

# Generate a chat response using the model for summarization.
# The transcript is sent as the assistant-role message, keeping the original request shape.
response = client.chat.completions.create(
    model="meta/llama-3.1-70b-instruct-maas",
    messages=[
        {"role": "system", "content": "You are a summarizer bot. Summarize the conversation in a brief format."},
        {"role": "user", "content": CHAT_COMPLETIONS_PROMPT},
        {"role": "assistant", "content": conversation_history}  # Add the entire conversation history
    ],
    extra_body={
        "extra_body": {
            "google": {
                "vertex_rag_store": {
                    "rag_resources": {"rag_corpus": rag_corpus.name},
                    "similarity_top_k": 2,
                }
            }
        }
    },
)

# Get the model's summary
summary = response.choices[0].message.content

# Wrap the summary in a JSON-serializable object
final_summary = {
    "summary": summary.strip()
}

# Serialize once and write the file once; the previous second write produced the
# same bytes from a dumps/loads round trip.
conversation_summary_json = json.dumps(final_summary, indent=4)

with open('conversation_summary.json', 'w') as f:
    f.write(conversation_summary_json)

# Parsed copy of the serialized summary, used by the following cells.
conversation_summary = json.loads(conversation_summary_json)

# Verify by printing the output
print(conversation_summary['summary'])

Here's a brief summary of the conversation:

The user asked about regulation tariffs and the assistant explained that there is no single tariff plan that suits all subscribers. The assistant advised the user to estimate their expected volume of usage and pattern of usage before deciding on a plan.

The user then asked for more details about regulation tariffs, and the assistant provided information on the cost of migration from one tariff plan to another (no charges payable), the prohibition on service providers charging fixed charges or processing fees on exclusive talk-time top-ups, and the requirement for operators to send data usage information through SMS or USSD after every session.

The user also asked about good postpaid plans for Airtel, and the assistant provided information on Airtel's postpaid plans, including the Airtel Family Postpaid Plans.

Additionally, the user asked if service providers can charge processing fees on talk-time top-ups, and the assistant replied that s

In [22]:
# Persist the summary object as pretty-printed JSON in the working directory.
with open("convo_1.json", 'w') as summary_file:
    summary_file.write(json.dumps(conversation_summary, indent=4))

In [23]:
# Upload the conversation summary into the corpus so later queries can retrieve it.
# The path is relative, matching how convo_1.json is written to the working directory;
# the former absolute "/content/..." path only resolved when the working directory was /content.
rag_file = rag.upload_file(
    corpus_name=rag_corpus.name,
    path="convo_1.json",
    display_name="convo_1.json",
    description="my test",
)

## Cleaning up

Clean up resources created in this notebook.

In [None]:
delete_rag_corpus = False  # @param {type:"boolean"}
delete_bucket = False  # @param {type:"boolean"}

if delete_rag_corpus:
    rag_corpus_list = rag.list_corpora()
    for rag_corpus in rag_corpus_list:
        rag.delete_corpus(name=rag_corpus.name)

if delete_bucket:
    ! gsutil -m rm -r $GS_BUCKET

## API reference

For more details on RAG corpus/file management and detailed support, visit https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/rag-api
