diff --git a/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/config.ini b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/config.ini new file mode 100644 index 000000000..7476e5319 --- /dev/null +++ b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/config.ini @@ -0,0 +1,41 @@ +[default] +project_id = genai-github-assets +region = us-central1 +enabled_programming_language = python +default_language = python + +[vector_search] +embedding_qpm = 100 +embedding_num_batch = 5 +me_region = us-central1 +me_index_name = genai-github-assets-index +me_gcs_bucket_region = us-central1 +me_gcs_bucket = +me_dimensions = 768 +embedding_model_name = textembedding-gecko +split_document_method = CHUNKS +chunk_size = 5000 +chunk_overlap = 100 +embedding_jsonl_file = data.jsonl +embedding_csv_file = embedding_df.csv + +[genai_chat] +model_name = gemini-1.0-pro +temperature = 0 +max_output_tokens = 1024 + +[genai_qna] +model_name = gemini-1.0-pro +max_output_tokens = 1024 +temperature = 0.3 +top_p = 0.8 +top_k = 40 +number_of_references_to_summarise = 6 + +[error_msg] +non_programming_question_error_msg = I apologize, I am configured to answer only in these programming languages Python programming. Specify the programming language in your query to get more accurate and helpful answers. +non_qna_programming_question_error_msg = I apologize, I am configured to answer only in these programming languages Python. Specify the programming language in your query to get more accurate and helpful answers. +unable_to_understand_question = I apologize, but I am not able to understand the question. Please try to elaborate and rephrase your question. +other_intent_error_msg = I apologize, I am allowed to answer programming related questions only. +no_reference_error_msg = I could not find any references that are directly related to your question in the knowledgebase. Please try to elaborate and rephrase your question. + diff --git a/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/qna_using_query_routing.ipynb b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/qna_using_query_routing.ipynb new file mode 100644 index 000000000..b22a4431e --- /dev/null +++ b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/qna_using_query_routing.ipynb @@ -0,0 +1,1008 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HNF3BjHL0x_D" + }, + "outputs": [], + "source": [ + "# Copyright 2024 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43291490b636" + }, + "source": [ + "# RAG - QnA using Query Routing\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \"Google
Open in Colab\n", + "
\n", + "
\n", + " \n", + " \"Google
Open in Colab Enterprise\n", + "
\n", + "
\n", + " \n", + " \"Vertex
Open in Workbench\n", + "
\n", + "
\n", + " \n", + " \"GitHub
View on GitHub\n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tF69osTL1aaM" + }, + "source": [ + "| | |\n", + "|-|-|\n", + "|Author(s) | [Charu Shelar](https://github.com/CharulataShelar) |" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tvgnzT1CKxrO" + }, + "source": [ + "## Overview\n", + "\n", + "This notebook showcases the use of query routing techniques to improve retrieval performance in an AI-powered learning assistant for a computer training institute. This assistant is designed to use LLM to classify the intent of the user query, which in turn determines the appropriate source(s) to answer the query. The solution has been built using the custom RAG approach and Gemini model (`Gemini Pro 1.0`)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Folder Structure\n", + "\n", + "1. qna_using_query_routing/\n", + " - config.ini : Configuration file.\n", + " - qna_using_query_routing.ipynb: Main demo notebook.\n", + "\n", + "2. utils/\n", + " - intent_routing.py : Contains methods for intent classification and route the request to respective componets.\n", + " - qna_vector_search.py : Answer QnA Type Questions using indexed documents.\n", + " - qna_using_query_routing_utils.py : Contains other utility functions.\n", + "\n", + "3. images/\n", + " - This folder contains images used in the notebook." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "96fad856c458" + }, + "source": [ + "## As a developer, you will learn the following steps to implement the solution\n", + "\n", + "1. Embed the document and create a vector search index using Vector Search (previously known as Matching Engine).\n", + " - Create document embeddings and upload them to GCS bucket.\n", + " - Create/update the vector search index using the embeddings.\n", + " - Python code function used here: `utils.qna_using_query_routing_utils.create_vector_search_index()`\n", + "\n", + "2. Build RAG (Retrieval-Augmented Generation) for intra document search using routing.\n", + " - Identify the intent of the user query and route the query.\n", + " - Answer programming questions using indexed documents i.e. using Vector Search's semantic search.\n", + " - Answer coding questions using the Gemini model if the knowledge base does not have the relevant context/content.\n", + " - To prevent hallucinations and maintain appropriate responses, the solution demonstrates how to guardrail the system's response to predetermined programming languages when handling user queries. The config.ini file can be used to configure the list of supported programming languages.\n", + " - Python code files used here: `utils/intent_routing.py` and `utils/qna_vector_search.py`\n", + "\n", + "3. To build chat UI interface using Gradio\n", + " - Create a chat interface to allow users to interact with the virtual assistant\n", + " - Create a separate tab on the UI to allow end users to index new documents from the GCS bucket" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9b8bfd66ab54" + }, + "source": [ + "## Google Cloud Services used\n", + "\n", + "1. Vector Search (previously Matching Engine)\n", + "2. Large Language Models - Gemini Pro 1.0, textembedding-gecko\n", + "\n", + "## Costs\n", + "This tutorial uses billable components of Google Cloud:\n", + "- Vector Search\n", + "- Gemini Pro 1.0\n", + "- textembedding-gecko\n", + "\n", + "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and use the [Pricing Calculator](https://cloud.google.com/products/calculator/?hl=en) to generate a cost estimate based on your projected usage." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88bdc5fac612" + }, + "source": [ + "## Solution Design Flow\n", + "\n", + "![genAI Asset Learning assistant](https://storage.googleapis.com/github-repo/generative-ai/gemini/use-cases/rag/qna-using-query-routing/architecture.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ab50bce05450" + }, + "source": [ + "### Response Generation using Query Routing:\n", + "1. The user starts a natural language query through a Gradio Chat User Interface (UI).\n", + "\n", + "2. Intent classification is done using Gemini model. It classifies the message into one of the following intents: `WELCOME`, `PROGRAMMING_QUESTION_AND_ANSWER`, `WRITE_CODE`, `FOLLOWUP`, or `CLOSE`.\n", + "\n", + "3. For the `WRITE_CODE` intent, the Gemini model is used to generate code using its coding capability.\n", + "\n", + "4. For the `PROGRAMMING_QUESTION_AND_ANSWER` intent, custom orchestration (RAG) retrieves context relevant to the user query from Vector Search and summarises relavent contexts. If the answer is not found, the user query is routed to the Gemini Model to respond using its knowledge.\n", + "\n", + "5. For the `FOLLOWUP` intent, such as explaining more or writing code for previous responses, the Gemini Model is used to generate responses using its code capability.\n", + "\n", + "6. For the `WELCOME` and `CLOSE` intents, the Gemini model is used to generate appropriate responses." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "61RBz8LLbxCR" + }, + "source": [ + "## Getting Started" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "No17Cw5hgx12" + }, + "source": [ + "### Install Vertex AI SDK and other required packages\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tFy3H3aPgx12" + }, + "outputs": [], + "source": [ + "!pip3 install --upgrade --user google-cloud-aiplatform \\\n", + "langchain==0.1.13 \\\n", + "pypdf==4.1.0 \\\n", + "gradio==3.41.2 \\\n", + "langchain-google-vertexai \\\n", + "--quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R5Xep4W9lq-Z" + }, + "source": [ + "### Restart runtime (Colab only)\n", + "\n", + "To use the newly installed packages, you must restart the runtime on Google Colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XRvKdaPDTznN" + }, + "outputs": [], + "source": [ + "import IPython\n", + "\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SbmM4z7FOBpM" + }, + "source": [ + "
\n", + "⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dmWOrTJ3gx13" + }, + "source": [ + "### Authenticate your notebook environment (Colab only)\n", + "\n", + "Authenticate your environment on Google Colab.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NyKGtVQjgx13" + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "if \"google.colab\" in sys.modules:\n", + " from google.colab import auth\n", + "\n", + " auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c0d78ca17444" + }, + "source": [ + "### Import required packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7cae7793b5c8", + "tags": [] + }, + "outputs": [], + "source": [ + "import configparser\n", + "import logging\n", + "import os\n", + "import uuid\n", + "import pandas as pd\n", + "from datetime import datetime\n", + "\n", + "import gradio as gr\n", + "import vertexai\n", + "from vertexai.generative_models import GenerativeModel\n", + "from vertexai.preview.language_models import TextEmbeddingModel\n", + "\n", + "from utils import qna_using_query_routing_utils\n", + "from utils.intent_routing import IntentRouting" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DF4l8DTdWgPY" + }, + "source": [ + "### Set Google Cloud project information and initialize Vertex AI SDK" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nqwi-5ufWp_B", + "tags": [] + }, + "outputs": [], + "source": [ + "PROJECT_ID = \"your-project-id\" # @param {type:\"string\"}\n", + "LOCATION = \"us-central1\" # @param {type:\"string\"}\n", + "\n", + "vertexai.init(project=PROJECT_ID, location=LOCATION)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "406eeb47322f" + }, + "source": [ + "Set up logging for the application" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e6713f15d223", + "tags": [] + }, + "outputs": [], + "source": [ + "logging.basicConfig(level=logging.INFO)\n", + "logger = logging.getLogger(__name__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f96bfd65600d" + }, + "source": [ + "### Update the project settings in config file" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0f4e674fefb2" + }, + "source": [ + "
\n", + "⚠️ Please do not change the configuration file name i.e. `config.ini` ⚠️\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9c18fe9c483f", + "tags": [] + }, + "outputs": [], + "source": [ + "config_file = \"config.ini\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5b7f3f56c323" + }, + "source": [ + "#### Update the settings in the config file\n", + "\n", + "**Note:** Some settings in the `config.ini` file are are updated from this notebook. \n", + "Additional parameters can be modified manually or using same code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "088b31c13ccb", + "tags": [] + }, + "outputs": [], + "source": [ + "config = configparser.ConfigParser()\n", + "config.read(config_file)\n", + "\n", + "config.set(\"default\", \"project_id\", PROJECT_ID)\n", + "config.set(\"default\", \"region\", LOCATION)\n", + "\n", + "with open(config_file, \"w\") as cf:\n", + " config.write(cf)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "efaf63dffa03" + }, + "source": [ + "### [One-time] Setup Vector Search for QnA\n", + "\n", + "- Download sample pdf document and save it in `DOCUMENT_FOLDER`\n", + "- Generate document embeddings, this will split and chunk the documents as configured using `chunk_size` and `chunk_overlap` in the `config.init` file\n", + "- Setup a Vector Search index (create vector search index, endpoint and deploy the index to the endpoint)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ab5b312d0a5" + }, + "source": [ + "#### Download sample pdf document and save it in `DOCUMENT_FOLDER`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d6685c9c2493" + }, + "outputs": [], + "source": [ + "# Download the sample document\n", + "!wget https://cfm.ehu.es/ricardo/docs/python/Learning_Python.pdf\n", + "\n", + "DOCUMENT_FOLDER = \"document\"\n", + "\n", + "# Create a \"document\" directory if it doesn't exist\n", + "if not os.path.exists(DOCUMENT_FOLDER):\n", + " os.makedirs(DOCUMENT_FOLDER)\n", + "\n", + "# Move the document to `DOCUMENT_FOLDER` folder\n", + "!mv Learning_Python.pdf {DOCUMENT_FOLDER}/Learning_Python.pdf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5828a9525d4a" + }, + "source": [ + "#### Generate document embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Read document(s) and split into chunks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7ec8e761e587", + "tags": [] + }, + "outputs": [], + "source": [ + "doc_splits = qna_using_query_routing_utils.get_split_documents(DOCUMENT_FOLDER)\n", + "\n", + "for idx, split in enumerate(doc_splits):\n", + " split.metadata[\"chunk\"] = idx\n", + "\n", + "# Log the number of documents after splitting\n", + "print(f\"Number of chunks = {len(doc_splits)}\")\n", + "\n", + "data = [{\"content\": ob.page_content, \"metadata\": ob.metadata} for ob in doc_splits]\n", + "\n", + "local_embedding_filename = config[\"vector_search\"][\"embedding_jsonl_file\"]\n", + "\n", + "with open(local_embedding_filename, \"w\") as outfile:\n", + " for item in data:\n", + " json_line = json.dumps(item)\n", + " outfile.write(json_line + \"\\n\")\n", + "\n", + "print(\"Saving document chunks in json file:\", local_embedding_filename)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Update the BUCKET_URI in config.ini file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "UID = datetime.now().strftime(\"%m%d%H%M\")\n", + "BUCKET_URI = f'gs://{config[\"default\"][\"project_id\"]}-embedding'\n", + "\n", + "config.set(\"vector_search\", \"me_gcs_bucket\", BUCKET_URI)\n", + "\n", + "with open(config_file, \"w\") as cf:\n", + " config.write(cf)\n", + "\n", + "print(\"UID :\", UID)\n", + "print(\"BUCKET_URI :\", BUCKET_URI)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a GCS bucket and move the json file in same bucket" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "! gsutil mb -l $LOCATION -p {PROJECT_ID} $BUCKET_URI\n", + "! gsutil cp $local_embedding_filename $BUCKET_URI" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create embeddings for all the text chunks in a batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "textembedding_model = TextEmbeddingModel.from_pretrained(\n", + " config[\"vector_search\"][\"embedding_model_name\"]\n", + ")\n", + "batch_prediction_job = textembedding_model.batch_predict(\n", + " dataset=[f\"{BUCKET_URI}/{local_embedding_filename}\"],\n", + " destination_uri_prefix=f\"{BUCKET_URI}/vertex-LLM-Batch-Prediction/{UID}\",\n", + ")\n", + "print(batch_prediction_job.display_name)\n", + "print(batch_prediction_job.resource_name)\n", + "print(batch_prediction_job.state)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "Download the embeddings file (jsonl) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "! gsutil cp -r $batch_prediction_job.gca_resource.output_info.gcs_output_directory json_files/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "296bdb5b7232" + }, + "source": [ + "#### Setup Vector Search index\n", + "\n", + "1. Create Vector Search index and Endpoint for Retrieval\n", + "2. Create and add document embeddings to Vector Store" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set display name for the vector search index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c179b4ddae1f" + }, + "outputs": [], + "source": [ + "me_index_name = f\"{PROJECT_ID}-index\" # @param {type:\"string\"}\n", + "me_region = \"us-central1\" # @param {type:\"string\"}\n", + "\n", + "config.set(\"vector_search\", \"me_index_name\", me_index_name)\n", + "config.set(\"vector_search\", \"me_region\", me_region)\n", + "\n", + "with open(config_file, \"w\") as cf:\n", + " config.write(cf)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6097b0f89377" + }, + "source": [ + "Save embeddings from jsonl to json file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2f2234ab9572" + }, + "outputs": [], + "source": [ + "jsonl_file_path = \"json_files/000000000000.jsonl\"\n", + "\n", + "df_list = []\n", + "with open(\"create_index.json\", \"w\") as outfile:\n", + " with open(jsonl_file_path, \"r\") as infile:\n", + " for line in infile:\n", + " df_l = json.loads(line)\n", + " line_details = {\n", + " \"id\": df_l[\"instance\"][\"metadata\"][\"chunk\"],\n", + " \"embedding\": df_l[\"predictions\"][0][\"embeddings\"][\"values\"],\n", + " }\n", + " json_line = json.dumps(line_details)\n", + " outfile.write(json_line + \"\\n\")\n", + "\n", + " df_list.append(\n", + " pd.DataFrame(\n", + " {\n", + " \"id\": df_l[\"instance\"][\"metadata\"][\"chunk\"],\n", + " # \"embedding\": str(df_l[\"predictions\"][0][\"embeddings\"][\"values\"]),\n", + " \"page_source\": df_l[\"instance\"][\"metadata\"][\"source\"],\n", + " \"text\": df_l[\"instance\"][\"content\"],\n", + " },\n", + " index=[0],\n", + " )\n", + " )\n", + "\n", + "embeddings = pd.concat(df_list, ignore_index=True)\n", + "embeddings.to_csv(config[\"vector_search\"][\"embedding_csv_file\"], index=False)\n", + "print(\"Embeddings data:\", embeddings.shape)\n", + "embeddings.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43e6baa6849f" + }, + "source": [ + "Move the embedings json file to GCS bucket" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bf8bc75b0cbb" + }, + "outputs": [], + "source": [ + "! gsutil cp create_index.json gs://genai-github-assets-embedding/create_index_json/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "72b6e23c2f72" + }, + "source": [ + "Create new vector search index and deploy it to a endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "17a027856dee" + }, + "outputs": [], + "source": [ + "(\n", + " index,\n", + " index_endpoint,\n", + " deployed_index_id,\n", + ") = qna_using_query_routing_utils.create_vector_search_index(BUCKET_URI)\n", + "print(\"Index :\", index)\n", + "print(\"Index endpoint :\", index_endpoint)\n", + "print(\"Deployed Index Id :\", deployed_index_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3afe6b737ae8" + }, + "source": [ + "**Note:** If you are re-running the code in this notebook and want to reuse the previously created index without creating a new one, execute the code in the cell below.\n", + "\n", + "This cell assumes that the index has already been created, deployed, and can be used for retrieval." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "306a79627a30" + }, + "outputs": [], + "source": [ + "# # Get details for already deployed index\n", + "# me_index_name = config[\"vector_search\"][\"me_index_name\"]\n", + "# me_region = config[\"vector_search\"][\"me_region\"]\n", + "\n", + "# index_endpoint, deployed_index_id = qna_using_query_routing_utils.get_deployed_index_id(me_index_name, me_region)\n", + "# print(\"Index endpoint :\", index_endpoint)\n", + "# print(\"Deployed Index Id :\", deployed_index_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EdvJRUWRNGHE" + }, + "source": [ + "## Chat interface using gradio app" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7fc5a97ffc91" + }, + "source": [ + "#### Mount Learning Assistant app with gradio \n", + "Now, Lets write a chat interface using Gradio that has two elements:\n", + "\n", + "**Chatbot:**\n", + "`chatbot` element has chat UI, and it shows messages from both user and chatbot(virtual assistant we buit)\n", + "\n", + "**Textbox:**\n", + "'msg` element allow gradio UI to accept the input text from the user. This text is passed to the system to fetch the answer either from Vector search or Gemini model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6ec233244b7f" + }, + "outputs": [], + "source": [ + "# Load language models for QnA and conversational interaction\n", + "model = GenerativeModel(config[\"genai_qna\"][\"model_name\"])\n", + "\n", + "# Initialize core components using configuration settings\n", + "genai_assistant = IntentRouting(\n", + " model=model,\n", + " index_endpoint=index_endpoint,\n", + " deployed_index_id=deployed_index_id,\n", + " config_file=config_file,\n", + " logger=logger,\n", + ")\n", + "\n", + "# Start the chat session and provide initial instructions to the chatbot\n", + "default_programming_language = config[\"default\"][\"default_language\"]\n", + "chat = model.start_chat(history=[])\n", + "_ = chat.send_message(\n", + " f\"\"\"You are a Programming Language Learning Assistant.\n", + " Your task is to undersand the question and respond with the descriptive answer for the same.\n", + "\n", + " Instructions:\n", + " 1. If programming language is not mentioned, then use {default_programming_language} as default programming language to write a code.\n", + " 2. Strictly follow the instructions mentioned in the question.\n", + " 3. If the question is not clear then you can answer \"I apologize, but I am not able to understand the question. Please try to elaborate and rephrase your question.\"\n", + "\n", + " If the question is about other programming language then DO NOT provide any answer, just say \"I apologize, but I am not able to understand the question. Please try to elaborate and rephrase your question.\"\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4051668710c8" + }, + "outputs": [], + "source": [ + "def respond(message, chat_history):\n", + " \"\"\"Handles user input within a Gradio chatbot interface.\"\"\"\n", + " (response, intent) = genai_assistant.classify_intent(\n", + " message,\n", + " session_state,\n", + " )\n", + "\n", + " # append response to history\n", + " chat_history.append((message, response))\n", + "\n", + " return \"\", chat_history" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "28636e373e78" + }, + "outputs": [], + "source": [ + "gr.close_all() # Ensure a clean Gradio interface\n", + "\n", + "with gr.Blocks() as demo:\n", + " with gr.Tab(\"Learning Assistant\"):\n", + " # Welcome message for the chatbot\n", + " bot_message = \"Hi there! I'm Generative AI powered Learning Assistant. I can help you with coding tasks, answer questions, and generate code. Just ask me anything you need, and I'll do my best to help!\" # pylint: disable=C0301:line-too-long\n", + "\n", + " # Generate a unique session identifier\n", + " session_state = str(uuid.uuid4())\n", + " logger.info(\"session_state : %s\", session_state)\n", + "\n", + " # Configure the chatbot's appearance using Chatbot element\n", + " chatbot = gr.Chatbot(\n", + " height=600,\n", + " label=\"\", # No display label for the chatbot\n", + " value=[[None, bot_message]], # Initialize with the welcome message\n", + " avatar_images=(\n", + " None,\n", + " \"https://fonts.gstatic.com/s/i/short-term/release/googlesymbols/smart_assistant/default/24px.svg\",\n", + " ), # Assistant avatar\n", + " elem_classes=\"message\",\n", + " show_label=False,\n", + " )\n", + "\n", + " # Configure the textbox for the user to enter questions.\n", + " msg = gr.Textbox(\n", + " scale=4,\n", + " label=\"\",\n", + " placeholder=\"Enter your question here..\",\n", + " elem_classes=[\"form\", \"message-row\"],\n", + " )\n", + "\n", + " # Event handling, Link the `respond` function to the textbox, enabling interaction\n", + " msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c4fff6052d41" + }, + "source": [ + "#### Launch the gradio app to view the chatbot\n", + "\n", + "**Note:**\n", + "1. For a better experience, Open the demo application interface in a new tab by clicking on the Localhost url generated after running this cell.\n", + "2. For debugging mode, set `debug=True`\n", + "\n", + "\n", + "**Example Questions to try on UI**\n", + "1. Where can we use python programming language?\n", + "2. What is the difference between list and set?\n", + "3. Fix the error in below code:\n", + "\n", + "```\n", + "def create_dataset(id: str): -> None\n", + "...\n", + "\n", + "SyntaxError: invalid syntax\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5b32aa361209" + }, + "outputs": [], + "source": [ + "demo.launch(share=True, debug=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "35553c14b7f8" + }, + "source": [ + "### Close the demo\n", + "\n", + "**Note:** Stop the previous cell to close the Gradio server running, then run this cell to free up the port utilised for running the server." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "704dbb0c853b" + }, + "outputs": [], + "source": [ + "demo.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2d70a97b2c91" + }, + "source": [ + "### Cleaning up\n", + "To clean up all Google Cloud resources used in this project, you can delete the Google Cloud project you used for the tutorial.\n", + "\n", + "Otherwise, you can delete the individual resources you created in this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c2aaf8a38ef6" + }, + "outputs": [], + "source": [ + "delete_bucket = False\n", + "\n", + "# Force undeployment of indexes and delete endpoint\n", + "index_endpoint.delete(force=True)\n", + "\n", + "# Delete indexes\n", + "index.delete()\n", + "\n", + "if delete_bucket:\n", + " ! gsutil rm -rf {BUCKET_URI}" + ] + } + ], + "metadata": { + "colab": { + "name": "qna_using_query_routing.ipynb", + "toc_visible": true + }, + "environment": { + "kernel": "conda-root-py", + "name": "workbench-notebooks.m119", + "type": "gcloud", + "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m119" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel) (Local)", + "language": "python", + "name": "conda-root-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/__init__.py b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/__init__.py new file mode 100644 index 000000000..d799a482e --- /dev/null +++ b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2023 Google LLC. This software is provided as-is, without warranty +# or representation for any use or purpose. Your use of it is subject to your +# agreement with Google. diff --git a/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/intent_routing.py b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/intent_routing.py new file mode 100644 index 000000000..e80fb27f3 --- /dev/null +++ b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/intent_routing.py @@ -0,0 +1,413 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generative AI Models based functions to respond questions""" + +import configparser + +# Utils +import json +import logging +from typing import List + +from google.cloud import aiplatform +from utils.qna_vector_search import QnAVectorSearch +from vertexai.generative_models import GenerationConfig, GenerativeModel + + +class IntentRouting: + """genai Assistant""" + + def __init__( + self, + model: GenerativeModel, + index_endpoint: aiplatform.MatchingEngineIndexEndpoint, + deployed_index_id: str, + config_file: str = "config.ini", + logger=logging.getLogger(), + ) -> None: + self.config = configparser.ConfigParser() + self.config.read(config_file) + + self.logger = logger + self.model = model + self.index_endpoint = index_endpoint + self.deployed_index_id = deployed_index_id + + self.genai_qna = QnAVectorSearch( + model=model, + index_endpoint=index_endpoint, + deployed_index_id=deployed_index_id, + config_file=config_file, + logger=logger, + ) + + self.genai_qna_parameters = GenerationConfig( + temperature=float(self.config["genai_qna"]["temperature"]), + max_output_tokens=int(self.config["genai_qna"]["max_output_tokens"]), + top_p=float(self.config["genai_qna"]["top_p"]), + top_k=int(self.config["genai_qna"]["top_k"]), + ) + + self.genai_chat_parameters = GenerationConfig( + temperature=float(self.config["genai_chat"]["temperature"]), + max_output_tokens=int(self.config["genai_chat"]["max_output_tokens"]), + ) + + def greetings(self, text: str) -> str: + """ + Generates a friendly greeting message in response to a 'WELCOME' intent. + + Leverages the provided chat model to generate a greeting tailored to the enabled programming languages. + + Args: + text (str): The user's original message. + + Returns: + str: The generated greeting message. + """ + + enabled_programming_language_list = self.config["default"][ + "enabled_programming_language" + ].split(",") + enabled_programming_language = ", ".join( + [lang.title() for lang in enabled_programming_language_list] + ) + + chat = self.model.start_chat() + response = chat.send_message( + f"""You are Generative AI powered genai Learning Assistant. + + You are trained ONLY to answers questions related to following programming languages: {enabled_programming_language} + + Write a brief greeting message:""" + ) + parameters_local = GenerationConfig( + temperature=0.7, + max_output_tokens=int(self.config["genai_chat"]["max_output_tokens"]), + ) + + response = chat.send_message(f"""{text}""", generation_config=parameters_local) + message = response.text + return message + + def closing(self, text: str) -> str: + """ + Generates a closing/thank you message in response to a 'CLOSE' intent. + + Leverages the provided chat model to generate a closing message. + + Args: + text (str): The user's original message. + + Returns: + str: The generated closing message. + """ + + parameters_local = GenerationConfig( + temperature=0.7, + max_output_tokens=int(self.config["genai_chat"]["max_output_tokens"]), + ) + chat = self.model.start_chat() + response = chat.send_message( + """You are Generative AI powered genai Learning Assistant. + Write a brief closing thank you message:""" + ) + response = chat.send_message(f"""{text}""", generation_config=parameters_local) + message = response.text + return message + + def genai_classify_intent(self, text: str) -> str: + """ + Classifies the intent of an incoming message using a strict intent classifier model. + + The supported intents are: + * 'WELCOME' + * 'WRITE_CODE' + * 'PROGRAMMING_QUESTION_AND_ANSWER' + * 'CLOSE' + * 'OTHER' + * 'FOLLOWUP' + + Args: + text (str): The user's message. + + Returns: + str: The classified intent. + """ + + response = self.model.generate_content( + f""" + You are strict intent classifier , Classify intent of messages into 5 categories + + Instructions: + 1. Classify the intents into one of these categories only: 'WELCOME', 'WRITE_CODE', 'PROGRAMMING_QUESTION_AND_ANSWER', 'CLOSE', 'OTHER', 'FOLLOWUP'. + 2. Messages can be read as case-insensitive. + 3. Reply ONLY with category of intent. Don't generate additional intent outside of this list. + + Definition of Intents: + 1. WELCOME : is the category with greeting message and to know about the assistant for example hi, hey there, Hello, Good morning, good afternoon, good evening, who are you?, what prgramming languesges do you know?, what do you do?, how can you help me?. + 2. WRITE_CODE : is the category with code writing , debugging, explain code message. user wants you to write a code. + 3. PROGRAMMING_QUESTION_AND_ANSWER: is the category with strictly programming language related descriptive or theoretical questions. Any other question non related to programming should go into others. + 4. CLOSE : is the category for closing the chat with messages like okay THANKS!, bye, Thanks, thank you, goodbye. + 5. OTHER : is the category where user is asking non information technology related quesiontion. for example Who is PM of india, what happended in G20 summit. + 6. FOLLOWUP : is the category with user is asking the followup questions like, such as write a code for same, give me the code for above, give me example, explain in detail, what above method is doing etc.. + + Examples: + MESSAGE: Hi + INTENT: WELCOME + + MESSAGE: What is the difference between list and set? + INTENT: PROGRAMMING_QUESTION_AND_ANSWER + + MESSAGE: Fix the error in below code: + def create_dataset(id: str): -> None + ... + + SyntaxError: invalid syntax + INTENT: WRITE_CODE + + What is the intent of the below message? + MESSAGE:{text} + INTENT:""", # pylint: disable=C0301:line-too-long + generation_config=self.genai_qna_parameters, + ) + + if response.to_dict()["candidates"][0]["finish_reason"] != 1: + self.logger.info( + "classify_intent: No response from QnA due to LLM safety checks." + ) + self.logger.info("LLM error code: %s\n", response.raw_prediction_response) + + intent = response.text + return ( + str(intent).replace('"', "").replace("'", "").replace("INTENT:", "").strip() + ) + + def ask_codey(self, text: str) -> str: + """ + Generates code in response to code-related questions ('WRITE_CODE' intent). + + Provides instructions to the chat model, handles potential errors (e.g., unclear questions, unsupported programming languages), and formats the generated code. + + Args: + text (str): The user's code-related query. + + Returns: + str: The generated code (or an error message if applicable). + """ + + unable_to_understand_question = self.config["error_msg"][ + "unable_to_understand_question" + ] + non_programming_question_error_msg = self.config["error_msg"][ + "non_programming_question_error_msg" + ] + enabled_programming_language = self.config["default"][ + "enabled_programming_language" + ] + default_language = self.config["default"]["default_language"] + chat = self.model.start_chat() + response = chat.send_message( + f""" + You are genai Programming Language Learning Assistant. + Your task is to undersand the question and write a code for same. + + Instructions: + 1. If programming language is not mentioned, then use {default_language} as default programming language to write a code. + 2. Strictly follow the instructions mentioned in the question. + 3. If the question is not clear then you can answer "{unable_to_understand_question}" + 4. Strictly answer the question if only {enabled_programming_language} is mentioned in question. + + If the question is about other programming language then DO NOT provide any answer, just say "{non_programming_question_error_msg}" + + """ + ) + + response = chat.send_message( + f"""{text}""", generation_config=self.genai_chat_parameters + ) + if response.to_dict()["candidates"][0]["finish_reason"] != 1: + self.logger.info( + "ask_codey: No response from QnA due to LLM safety checks." + ) + self.logger.info("LLM error code: %s\n", response.raw_prediction_response) + response = response.text + response = response.replace("```", "\n\n```") + response = response.replace("```java", "``` java") + return response + + def get_programming_lanuage_from_query( + self, text: str, enabled_programming_language: List + ) -> List[str]: + """ + Extracts programming languages mentioned in a user's query. + + Args: + text (str): The user's query. + enabled_programming_language (List): List of supported languages. + + Returns: + list: A list of programming languages extracted from the query (potentially empty). + """ + + response = self.model.generate_content( + f""" + You are strict programming languages extractor. + + Instructions: + 1. Extract only programming languages from message. + 2. Don't return any other languages other than programming. + 3. return [] if no programming lanuage in mentioned in message. + 4. {enabled_programming_language} these are the programming languages. + + Examples: + write a code for fibonacci series in C++? : ["C++"] + write a code using C# to generate palindrome series : ["C#"] + using python, write a sample application code to create endpoint? : ["Python"] + what are classes in Java? : ["Java"] + write a code for reverse a string : [] + what are data types? : [] + + What are the programming languages mentioned in below message? + MESSAGE:{text} + programming languages:""", # pylint: disable=C0301:line-too-long + generation_config=self.genai_qna_parameters, + ) + + programming_lang = response.text + program_lang_in_query = [] + if programming_lang: + try: + programming_lang = programming_lang.replace("'", '"') + programming_lang = json.loads(programming_lang) + program_lang_in_query = [ + x.lower().replace(" ", "").strip() for x in programming_lang + ] + except Exception: # pylint: disable=W0718,W0703 + self.logger.info("Error while extracting programming language.") + return program_lang_in_query + + def check_programming_language_in_query( + self, text: str + ) -> tuple[List[str], set[str]]: + """ + Identifies supported programming languages mentioned in a user's query. + + Args: + text (str): The user's query. + + Returns: + tuple: + * list: All programming languages found in the query. + * set: Programming languages in the query that are supported by the assistant. + """ + + enabled_programming_language = self.config["default"][ + "enabled_programming_language" + ] + enabled_programming_language_list = enabled_programming_language.split(",") + enabled_programming_language_list = [ + x.lower().replace(" ", "").strip() + for x in enabled_programming_language_list + ] + program_lang_in_query = self.get_programming_lanuage_from_query( + text, enabled_programming_language_list + ) + allowed_language_in_query = set(enabled_programming_language_list).intersection( + set(program_lang_in_query) + ) + + return program_lang_in_query, allowed_language_in_query + + def classify_intent(self, text: str, session_state: str) -> tuple[str, str]: + """ + Orchestrates intent classification, response generation, and error handling. + + Handles the following intents: + * 'WELCOME' + * 'WRITE_CODE' + * 'PROGRAMMING_QUESTION_AND_ANSWER' + * 'FOLLOWUP' + * 'CLOSE' + * 'OTHER' + + Args: + text (str): User's message. + session_state (str): Unique ID for tracking the conversation. + + Returns: + tuple: + * str: Response to the user's message. + * str: Classified intent. + """ + + try: + response = "" + + intent = self.genai_classify_intent(text) + self.logger.info("Classified intent: %s", intent) + + if intent == "WELCOME": + response = self.greetings(text) + elif intent == "WRITE_CODE": + ( + program_lang_in_query, + allowed_language_in_query, + ) = self.check_programming_language_in_query(text) + self.logger.info("program_lang_in_query: %s", program_lang_in_query) + self.logger.info( + "allowed_language_in_query: %s", allowed_language_in_query + ) + if ( + len(program_lang_in_query) > 0 + and len(allowed_language_in_query) == 0 + ): + response = self.config["error_msg"][ + "non_programming_question_error_msg" + ] + else: + response = self.ask_codey(text) + elif intent == "PROGRAMMING_QUESTION_AND_ANSWER": + if self.index_endpoint and self.deployed_index_id: + input_token_len = self.model.count_tokens(text).total_tokens + self.logger.info("Input_token_len for QnA: %s", input_token_len) + + qna_answer = self.genai_qna.ask_qna(text) + + if qna_answer: + response = qna_answer + else: + self.logger.info("Asking codey when no answer from QnA") + response = self.ask_codey(text) + else: + self.logger.info( + "Asking codey as Index or Endpoint is not available" + ) + response = self.ask_codey(text) + elif intent == "FOLLOWUP": + response = self.ask_codey(text + " based on previous message") + elif intent == "CLOSE": + response = self.closing(text) + else: + response = self.config["error_msg"]["other_intent_error_msg"] + except Exception as e: # pylint: disable=W0718,W0703,C0103 + self.logger.error("Session : %s", session_state) + self.logger.error("Error : %s", e) + + return ( + "We're sorry, but we encountered a problem. Please try again.", + "ERROR", + ) + return (response, intent) diff --git a/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/qna_using_query_routing_utils.py b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/qna_using_query_routing_utils.py new file mode 100644 index 000000000..f9a5c3fe9 --- /dev/null +++ b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/qna_using_query_routing_utils.py @@ -0,0 +1,163 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility file with more generic fuctions""" + +import configparser +from datetime import datetime + +# Utils +import os +from typing import List + +from google.cloud import aiplatform +from langchain.document_loaders import TextLoader, UnstructuredPDFLoader +from langchain.text_splitter import CharacterTextSplitter +from langchain_core.documents.base import Document + +config_file = "config.ini" +config = configparser.ConfigParser() +config.read(config_file) + + +def get_deployed_index_id( + me_index_name: str, location: str +) -> tuple[aiplatform.MatchingEngineIndexEndpoint, str]: + """ + Retrieves the deployed index ID and index endpoint from Vector Search. + + Checks if a vector search index with the specified name is already deployed in the given location. + + Args: + me_index_name (str): The name of the Vector Search index. + location (str): The location where the index is deployed. + + Returns: + tuple: A tuple containing: + * aiplatform.MatchingEngineIndexEndpoint: The index endpoint object, or None if not found. + * str: The deployed index ID, or None if not found. + """ + + index_endpoint_name = me_index_name + "-endpoint" + + list_endpoints = aiplatform.MatchingEngineIndexEndpoint.list( + filter=f"display_name={index_endpoint_name}", location=location + ) + if list_endpoints: + index_endpoint_resource_name = list_endpoints[0].resource_name + + index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=index_endpoint_resource_name + ) + if index_endpoint.deployed_indexes: + deployed_index_id = index_endpoint.deployed_indexes[0].id + else: + deployed_index_id = None + print( + f"Index endpoint resource is not available for: {index_endpoint_resource_name}" + ) + else: + index_endpoint = deployed_index_id = None + print(f"Index endpoint is not available for: {index_endpoint_name}") + + return index_endpoint, deployed_index_id + + +def get_split_documents(index_path: str) -> List[Document]: + """ + Loads documents from a folder and splits them into manageable chunks. + + Supports both PDF and plain text documents. + + Args: + index_path (str): The path to the folder containing the document(s). + + Returns: + List['str']: A list of strings, representing the chunked document(s). + """ + + split_docs = [] + + if index_path[-1] != "/": + index_path = index_path + "/" + if index_path == "": + index_path = "." + + for file_name in os.listdir(index_path): + print(f"Chunking input file: {file_name}") + if file_name.endswith(".pdf"): + loader = UnstructuredPDFLoader(index_path + file_name) + else: + loader = TextLoader(index_path + file_name) + + text_splitter = CharacterTextSplitter( + chunk_size=int(config["vector_search"]["chunk_size"]), + chunk_overlap=int(config["vector_search"]["chunk_overlap"]), + ) + split_docs.extend(text_splitter.split_documents(loader.load())) + + return split_docs + + +def create_vector_search_index( + bucket_uri: str, +) -> tuple[aiplatform.MatchingEngineIndex, aiplatform.MatchingEngineIndexEndpoint, str]: + """ + Creates a Vertex AI Matching Engine index, endpoint, and deploys the index. + + Args: + bucket_uri (str): The Cloud Storage bucket URI where embedding data is stored. + + Returns: + tuple: A tuple containing: + * aiplatform.MatchingEngineIndex: The created index object. + * aiplatform.MatchingEngineIndexEndpoint: The created index endpoint object. + * str: The ID of the deployed index. + """ + + print("Creating new vector search index..") + + UID = datetime.now().strftime("%m%d%H%M") + + # create index + my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index( + display_name=config["vector_search"]["me_index_name"], + location=config["vector_search"]["me_region"], + contents_delta_uri=bucket_uri, + dimensions=int(config["vector_search"]["me_dimensions"]), + approximate_neighbors_count=int( + config["genai_qna"]["number_of_references_to_summarise"] + ), + distance_measure_type="DOT_PRODUCT_DISTANCE", + ) + print(f"Created new index : {my_index.display_name} with ID: {my_index.name}") + + # create IndexEndpoint + index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=f'{config["vector_search"]["me_index_name"]}-endpoint', + public_endpoint_enabled=True, + ) + print( + f"Created new index endpoint : {index_endpoint.display_name} with ID: {index_endpoint.name}" + ) + + deployed_index_id = ( + f'{config["vector_search"]["me_index_name"].replace("-", "_")}_deployed_{UID}' + ) + + # deploy the Index to the Index Endpoint + index_endpoint.deploy_index(index=my_index, deployed_index_id=deployed_index_id) + print(f"Deployed index to endpoint : {deployed_index_id}") + + return my_index, index_endpoint, deployed_index_id diff --git a/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/qna_vector_search.py b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/qna_vector_search.py new file mode 100644 index 000000000..4d8cc2ed5 --- /dev/null +++ b/gemini/use-cases/retrieval-augmented-generation/qna_using_query_routing/utils/qna_vector_search.py @@ -0,0 +1,144 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Answer QnA Type Questions using genai Content""" + +# Utils +import configparser +import logging + +from google.cloud import aiplatform +import pandas as pd +from vertexai.generative_models import GenerativeModel +from vertexai.language_models import TextEmbeddingModel + + +class QnAVectorSearch: + """genai Generate Answer From genai Content""" + + def __init__( + self, + model: GenerativeModel, + index_endpoint: aiplatform.MatchingEngineIndexEndpoint, + deployed_index_id: str, + config_file: str, + logger=logging.getLogger(), + ) -> None: + self.config = configparser.ConfigParser() + self.config.read(config_file) + + self.logger = logger + self.model = model + self.index_endpoint = index_endpoint + self.deployed_index_id = deployed_index_id + + # Initalizing embedding model + self.text_embedding_model = TextEmbeddingModel.from_pretrained( + self.config["vector_search"]["embedding_model_name"] + ) + + self.embedding_df = pd.read_csv( + self.config["vector_search"]["embedding_csv_file"] + ) + + self.num_neighbors = int( + self.config["genai_qna"]["number_of_references_to_summarise"] + ) + + # Default retrieval prompt template + self.prompt_template = """You are a programming language learning assistant, helping the students answer their questions based on the following context. Explain the answer in detail for students. + + Instructions: + 1. Think step-by-step and then answer. + 2. Explain the answer in detail. + 3. If the answer to the question cannot be determined from the context alone, say "I cannot determine the answer to that." + 4. If the context is empty, just say "I could not find any references that are directly related to your question." + + Context: + ============= + {context} + ============= + + What is the Detailed explanation of the answer to the following question? + Question: {question} + Detailed explanation of Answer:""" # pylint: disable=line-too-long + + def find_relevant_context(self, query: str) -> str: + """ + Searches the vector index to retrieve relevant context based on a query embedding. + + Args: + query (str, optional): The query text. + + Returns: + str: The concatenated text of relevant documents found in the index. + """ + + # Generate the embeddings for user question + vector = self.text_embedding_model.get_embeddings([query]) + + queries = [vector[0].values] + + response = self.index_endpoint.find_neighbors( + deployed_index_id=self.deployed_index_id, + queries=queries, + num_neighbors=self.num_neighbors, + ) + + context = "" + for neighbor_index in range(len(response[0])): + context = ( + context + + self.embedding_df[ + (self.embedding_df["id"] == response[0][neighbor_index].id) + | (self.embedding_df["id"] == int(response[0][neighbor_index].id)) + ].text.values[0] + + " \n" + ) + + return context + + def ask_qna(self, question: str) -> str: + """Retrieves relevant context using vector search and generates an answer using a QnA model. + Args: + question (str): The user's question. + + Returns: + str: The generated answer from the QnA model, or None if no valid answer could be determined. + """ + + # Read context from relavent documents + self.logger.info("QnA: question: %s", question) + + context = self.find_relevant_context(question) + # self.logger.info("QnA: context: %s", context) + + # Get response + if len(context) > 0: + response = self.model.generate_content( + self.prompt_template.format(context=context, question=question) + ) + + if response and int(response.candidates[0].finish_reason) == 1: + answer = response.text + self.logger.info("QnA: response: %s", answer) + if ( + "I cannot determine the answer to that." in answer + or "I could not find any references that are directly related to your question." + in answer + ): + return "" + else: + return answer + return ""