From 0574da54c58c01352311d1ac03692d59e94e606b Mon Sep 17 00:00:00 2001 From: Shubhadeep Das Date: Thu, 9 Nov 2023 23:07:14 +0530 Subject: [PATCH] Add initial files for the repository --- .pre-commit-config.yaml | 11 + RetrievalAugmentedGeneration/.gitattributes | 1 + RetrievalAugmentedGeneration/.gitignore | 25 + RetrievalAugmentedGeneration/Dockerfile | 11 + .../Dockerfile.notebooks | 28 + RetrievalAugmentedGeneration/README.md | 156 + .../chain_server/__init__.py | 16 + .../chain_server/chains.py | 90 + .../chain_server/configuration.py | 167 + .../chain_server/configuration_wizard.py | 411 ++ .../chain_server/openapi_schema.json | 244 + .../chain_server/server.py | 101 + .../chain_server/trt_llm.py | 557 ++ .../chain_server/utils.py | 154 + .../deploy/compose.env | 23 + .../deploy/config.yaml | 52 + .../deploy/docker-compose.yaml | 174 + .../docs/architecture.md | 89 + .../docs/chat_server.md | 182 + .../docs/configuration.md | 74 + RetrievalAugmentedGeneration/docs/frontend.md | 33 + .../docs/jupyter_server.md | 40 + .../docs/llm_inference_server.md | 31 + .../docs/support_matrix.md | 30 + .../frontend/Dockerfile | 12 + .../frontend/README.md | 38 + .../frontend/frontend/__init__.py | 94 + .../frontend/frontend/__main__.py | 122 + .../frontend/frontend/api.py | 72 + .../frontend/frontend/assets/__init__.py | 38 + .../frontend/frontend/assets/kaizen-theme.css | 13 + .../frontend/assets/kaizen-theme.json | 336 ++ .../frontend/frontend/chat_client.py | 95 + .../frontend/frontend/configuration.py | 44 + .../frontend/frontend/configuration_wizard.py | 411 ++ .../frontend/frontend/pages/__init__.py | 19 + .../frontend/frontend/pages/converse.py | 127 + .../frontend/frontend/pages/kb.py | 56 + .../frontend/frontend/static/404.html | 1 + .../WuNGAl0x4o1D5HqLxhHMt/_buildManifest.js | 1 + .../WuNGAl0x4o1D5HqLxhHMt/_ssgManifest.js | 1 + .../static/chunks/78-a36dca5d49fafb86.js | 1 + .../chunks/framework-7a7e500878b44665.js | 33 + .../static/chunks/main-92011a1a7f336a6f.js | 1 + .../chunks/pages/_app-f21c0780e30f5eb6.js | 5 + .../chunks/pages/_app-f55c3b932a623280.js | 5 + .../chunks/pages/_error-54de1933a164a1ff.js | 1 + .../chunks/pages/converse-39686323b565eff0.js | 1 + .../chunks/pages/converse-61880f01babd873a.js | 1 + .../chunks/pages/index-1a1d31dae38463f7.js | 1 + .../chunks/pages/index-6a3f286eb0986c10.js | 1 + .../chunks/pages/kb-cf0d102293dc0a74.js | 1 + .../chunks/pages/tuning-0b7bb1111c2d2a56.js | 1 + .../chunks/polyfills-78c92fac7aa8fdd8.js | 1 + .../static/chunks/webpack-5146130448d8adf7.js | 1 + .../_next/static/css/7636246223312442.css | 1 + .../_next/static/css/98b512633409f7e1.css | 1 + .../s7oUSppGTRWsY8BXJmxYB/_buildManifest.js | 1 + .../s7oUSppGTRWsY8BXJmxYB/_ssgManifest.js | 1 + .../frontend/frontend/static/converse.html | 1 + .../frontend/frontend/static/favicon.ico | Bin 0 -> 25931 bytes .../frontend/frontend/static/index.html | 1 + .../frontend/frontend/static/kb.html | 1 + .../frontend/frontend/static/next.svg | 1 + .../frontend/frontend/static/vercel.svg | 1 + .../frontend/frontend_js/.eslintrc.json | 3 + .../frontend/frontend_js/.gitignore | 36 + .../frontend/frontend_js/.npmrc | 4 + .../frontend/frontend_js/README.md | 38 + .../frontend/frontend_js/next.config.js | 7 + .../frontend/frontend_js/package-lock.json | 4751 +++++++++++++++++ .../frontend/frontend_js/package.json | 32 + .../frontend/frontend_js/public/favicon.ico | Bin 0 -> 25931 bytes .../frontend/frontend_js/public/next.svg | 1 + .../frontend/frontend_js/public/vercel.svg | 1 + 
.../frontend_js/src/components/AppShell.tsx | 76 + .../src/components/GradioPortal.tsx | 29 + .../frontend_js/src/components/Navbar.tsx | 47 + .../frontend_js/src/components/NavbarLink.tsx | 19 + .../src/components/StoreProvider.tsx | 82 + .../frontend/frontend_js/src/pages/_app.tsx | 24 + .../frontend_js/src/pages/_document.tsx | 13 + .../frontend_js/src/pages/converse.tsx | 18 + .../frontend/frontend_js/src/pages/index.tsx | 15 + .../frontend/frontend_js/src/pages/kb.tsx | 18 + .../frontend_js/src/styles/globals.css | 9 + .../src/styles/gradio-embed.module.css | 13 + .../frontend/frontend_js/tsconfig.json | 23 + .../frontend/requirements.txt | 8 + .../images/docker-output.png | Bin 0 -> 343710 bytes .../images/image0.png | Bin 0 -> 187570 bytes .../images/image1.png | Bin 0 -> 165091 bytes .../images/image2.png | Bin 0 -> 394513 bytes .../images/image3.jpg | Bin 0 -> 77746 bytes .../images/image4.jpg | Bin 0 -> 263344 bytes .../notebooks/01-llm-streaming-client.ipynb | 181 + .../notebooks/02_langchain_simple.ipynb | 378 ++ .../notebooks/03_llama_index_simple.ipynb | 461 ++ .../04_llamaindex_hier_node_parser.ipynb | 466 ++ .../notebooks/05_dataloader.ipynb | 181 + .../notebooks/dataset.zip | 3 + .../imgs/data_connection_langchain.jpeg | Bin 0 -> 1028460 bytes .../notebooks/imgs/llama_hub.png | Bin 0 -> 546602 bytes .../notebooks/imgs/vector_stores.jpeg | Bin 0 -> 878263 bytes .../notebooks/requirements.txt | 14 + RetrievalAugmentedGeneration/requirements.txt | 12 + 106 files changed, 11206 insertions(+) create mode 100644 .pre-commit-config.yaml create mode 100644 RetrievalAugmentedGeneration/.gitattributes create mode 100644 RetrievalAugmentedGeneration/.gitignore create mode 100644 RetrievalAugmentedGeneration/Dockerfile create mode 100644 RetrievalAugmentedGeneration/Dockerfile.notebooks create mode 100644 RetrievalAugmentedGeneration/README.md create mode 100644 RetrievalAugmentedGeneration/chain_server/__init__.py create mode 100644 RetrievalAugmentedGeneration/chain_server/chains.py create mode 100644 RetrievalAugmentedGeneration/chain_server/configuration.py create mode 100644 RetrievalAugmentedGeneration/chain_server/configuration_wizard.py create mode 100644 RetrievalAugmentedGeneration/chain_server/openapi_schema.json create mode 100644 RetrievalAugmentedGeneration/chain_server/server.py create mode 100644 RetrievalAugmentedGeneration/chain_server/trt_llm.py create mode 100644 RetrievalAugmentedGeneration/chain_server/utils.py create mode 100644 RetrievalAugmentedGeneration/deploy/compose.env create mode 100644 RetrievalAugmentedGeneration/deploy/config.yaml create mode 100644 RetrievalAugmentedGeneration/deploy/docker-compose.yaml create mode 100644 RetrievalAugmentedGeneration/docs/architecture.md create mode 100644 RetrievalAugmentedGeneration/docs/chat_server.md create mode 100644 RetrievalAugmentedGeneration/docs/configuration.md create mode 100644 RetrievalAugmentedGeneration/docs/frontend.md create mode 100644 RetrievalAugmentedGeneration/docs/jupyter_server.md create mode 100644 RetrievalAugmentedGeneration/docs/llm_inference_server.md create mode 100644 RetrievalAugmentedGeneration/docs/support_matrix.md create mode 100644 RetrievalAugmentedGeneration/frontend/Dockerfile create mode 100644 RetrievalAugmentedGeneration/frontend/README.md create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/__init__.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/__main__.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/api.py create 
mode 100644 RetrievalAugmentedGeneration/frontend/frontend/assets/__init__.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.css create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.json create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/chat_client.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/configuration.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/configuration_wizard.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/pages/__init__.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/pages/converse.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/pages/kb.py create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/404.html create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_buildManifest.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_ssgManifest.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/78-a36dca5d49fafb86.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/framework-7a7e500878b44665.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/main-92011a1a7f336a6f.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/_app-f21c0780e30f5eb6.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/_app-f55c3b932a623280.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/_error-54de1933a164a1ff.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/converse-39686323b565eff0.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/converse-61880f01babd873a.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/index-1a1d31dae38463f7.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/index-6a3f286eb0986c10.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/kb-cf0d102293dc0a74.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/pages/tuning-0b7bb1111c2d2a56.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/polyfills-78c92fac7aa8fdd8.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/webpack-5146130448d8adf7.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/css/7636246223312442.css create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/css/98b512633409f7e1.css create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/s7oUSppGTRWsY8BXJmxYB/_buildManifest.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/s7oUSppGTRWsY8BXJmxYB/_ssgManifest.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/converse.html create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/favicon.ico create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/index.html create 
mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/kb.html create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/next.svg create mode 100644 RetrievalAugmentedGeneration/frontend/frontend/static/vercel.svg create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/.eslintrc.json create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/.gitignore create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/.npmrc create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/README.md create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/next.config.js create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/package-lock.json create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/package.json create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/public/favicon.ico create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/public/next.svg create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/public/vercel.svg create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/components/AppShell.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/components/GradioPortal.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/components/Navbar.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/components/NavbarLink.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/components/StoreProvider.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/pages/_app.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/pages/_document.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/pages/converse.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/pages/index.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/pages/kb.tsx create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/styles/globals.css create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/src/styles/gradio-embed.module.css create mode 100644 RetrievalAugmentedGeneration/frontend/frontend_js/tsconfig.json create mode 100644 RetrievalAugmentedGeneration/frontend/requirements.txt create mode 100644 RetrievalAugmentedGeneration/images/docker-output.png create mode 100644 RetrievalAugmentedGeneration/images/image0.png create mode 100644 RetrievalAugmentedGeneration/images/image1.png create mode 100644 RetrievalAugmentedGeneration/images/image2.png create mode 100644 RetrievalAugmentedGeneration/images/image3.jpg create mode 100644 RetrievalAugmentedGeneration/images/image4.jpg create mode 100644 RetrievalAugmentedGeneration/notebooks/01-llm-streaming-client.ipynb create mode 100644 RetrievalAugmentedGeneration/notebooks/02_langchain_simple.ipynb create mode 100644 RetrievalAugmentedGeneration/notebooks/03_llama_index_simple.ipynb create mode 100644 RetrievalAugmentedGeneration/notebooks/04_llamaindex_hier_node_parser.ipynb create mode 100644 RetrievalAugmentedGeneration/notebooks/05_dataloader.ipynb create mode 100644 RetrievalAugmentedGeneration/notebooks/dataset.zip create mode 100644 RetrievalAugmentedGeneration/notebooks/imgs/data_connection_langchain.jpeg create mode 100644 RetrievalAugmentedGeneration/notebooks/imgs/llama_hub.png create mode 100644 RetrievalAugmentedGeneration/notebooks/imgs/vector_stores.jpeg create mode 100644 
RetrievalAugmentedGeneration/notebooks/requirements.txt create mode 100644 RetrievalAugmentedGeneration/requirements.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..22aece8e7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: + - repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.4.2 + hooks: + - id: insert-license + files: ^RetrievalAugmentedGeneration/ + exclude: ^RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/|^RetrievalAugmentedGeneration/llm-inference-server/ensemble_models + types: [python] + args: + - --license-filepath + - RetrievalAugmentedGeneration/LICENSE.md diff --git a/RetrievalAugmentedGeneration/.gitattributes b/RetrievalAugmentedGeneration/.gitattributes new file mode 100644 index 000000000..c8a8d73b2 --- /dev/null +++ b/RetrievalAugmentedGeneration/.gitattributes @@ -0,0 +1 @@ +notebooks/dataset.zip filter=lfs diff=lfs merge=lfs -text diff --git a/RetrievalAugmentedGeneration/.gitignore b/RetrievalAugmentedGeneration/.gitignore new file mode 100644 index 000000000..baec55148 --- /dev/null +++ b/RetrievalAugmentedGeneration/.gitignore @@ -0,0 +1,25 @@ +# Python Exclusions +.venv +__pycache__ + +# Sphinx Exclusions +_build + +# Helm Exclusions +**/charts/*.tgz + +# project temp files +deploy/*.log +deploy/*.txt +**/my.* +**/my-* + +# Next JS Exclusions +**/.next +frontend/frontend_js/out +frontend-sdxl/frontend_js/out +**/node_modules + +# Docker Compose exclusions +volumes/ +uploaded_files/ diff --git a/RetrievalAugmentedGeneration/Dockerfile b/RetrievalAugmentedGeneration/Dockerfile new file mode 100644 index 000000000..5d71ee7ab --- /dev/null +++ b/RetrievalAugmentedGeneration/Dockerfile @@ -0,0 +1,11 @@ +ARG BASE_IMAGE_URL=nvcr.io/nvidia/pytorch +ARG BASE_IMAGE_TAG=23.08-py3 + + +FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} +COPY chain_server /opt/chain_server +RUN --mount=type=bind,source=requirements.txt,target=/opt/requirements.txt \ + python3 -m pip install --no-cache-dir -r /opt/requirements.txt + +WORKDIR /opt +ENTRYPOINT ["uvicorn", "chain_server.server:app"] diff --git a/RetrievalAugmentedGeneration/Dockerfile.notebooks b/RetrievalAugmentedGeneration/Dockerfile.notebooks new file mode 100644 index 000000000..71e0f04f3 --- /dev/null +++ b/RetrievalAugmentedGeneration/Dockerfile.notebooks @@ -0,0 +1,28 @@ +# Use a base image with Python +FROM python:3.10-slim + +# Set working directory +WORKDIR /app + +#COPY notebooks +COPY notebooks/*.ipynb . + +RUN mkdir -p /app/imgs + +COPY notebooks/dataset.zip . + +COPY notebooks/imgs/* imgs/ + +COPY chain_server/trt_llm.py . + +COPY notebooks/requirements.txt . +#Run pip dependencies +RUN pip3 install -r requirements.txt + +RUN apt-get update && apt-get install -y unzip wget git libgl1-mesa-glx libglib2.0-0 + +# Expose port 8888 for JupyterLab +EXPOSE 8888 + +# Start JupyterLab when the container runs +CMD ["jupyter", "lab", "--allow-root", "--ip=0.0.0.0","--NotebookApp.token=''", "--port=8888"] diff --git a/RetrievalAugmentedGeneration/README.md b/RetrievalAugmentedGeneration/README.md new file mode 100644 index 000000000..aaa583ad4 --- /dev/null +++ b/RetrievalAugmentedGeneration/README.md @@ -0,0 +1,156 @@ +# Retrieval Augmented Generation + +## Project Details +**Project Goal**: An external reference for a chatbot to question answer off public press releases & tech blogs. 
Performs document ingestion and provides a Q&A interface using the best open models in any cloud or customer datacenter, and leverages GPU-accelerated Milvus for efficient vector storage and retrieval, along with TRT-LLM and a custom LangChain LLM wrapper, to achieve fast inference speeds.
+
+## Components
+- **LLM**: Llama2 -- 7B, 13B, and 70B are all supported; the 13B and 70B variants generate good responses. We wanted the best open-source model available at the time of creation.
+- **LLM Backend**: TRT-LLM, for speed.
+- **Vector DB**: Milvus, because it is GPU accelerated.
+- **Embedding Model**: e5-large-v2, since it appeared to be one of the best embedding models available at the time.
+- **Framework(s)**: LangChain and LlamaIndex.
+
+This reference workflow uses a variety of components and services to customize and deploy the RAG-based chatbot. The following diagram illustrates how they work together. Refer to the [detailed architecture guide](./docs/architecture.md) to understand more about these components and how they are tied together.
+
+
+![Diagram](./../RetrievalAugmentedGeneration/images/image3.jpg)
+
+
+# Getting Started
+This section is a step-by-step guide to set up and try out this example workflow.
+
+## Prerequisites
+Before proceeding with this guide, make sure you meet the following prerequisites:
+
+- You should have at least one NVIDIA GPU. For this guide, we used an A100 data center GPU.
+
+    - NVIDIA driver version 535 or newer. To check the driver version, run: ``nvidia-smi --query-gpu=driver_version --format=csv,noheader``.
+    - If you are running multiple GPUs, they must all be set to the same mode (i.e. Compute vs. Display). You can check the compute mode for each GPU using
+    ``nvidia-smi -q -d compute``
+
+### Setup the following
+
+- Docker and Docker Compose are essential. Please follow the [installation instructions](https://docs.docker.com/engine/install/ubuntu/).
+
+    Note:
+    Please do **not** use the Docker that is packaged with Ubuntu, as a newer version of Docker is required for proper Docker Compose support.
+
+    Make sure your user account is able to execute Docker commands.
+
+
+- NVIDIA Container Toolkit is also required. Refer to the [installation instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+
+
+- NGC Account and API Key
+
+    - Please refer to the [instructions](https://docs.nvidia.com/ngc/gpu-cloud/ngc-overview/index.html) to create an account and generate an NGC API key.
+    - Log in to `nvcr.io` with Docker using the following command:
+      ```
+      docker login nvcr.io
+      ```
+
+- You can download the Llama2 Chat Model Weights from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or [HuggingFace](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf/).
+
+    **Note for checkpoints downloaded from Meta**:
+
+    When downloading model weights from Meta, you can follow the instructions up to the point of downloading the models using ``download.sh``. There is no need to deploy the model using the steps mentioned in the repository; we will use Triton to deploy the model.
+
+    The Meta download script places two additional files, namely tokenizer.model and tokenizer_checklist.chk, outside of the model checkpoint directory. Ensure that you copy these files into the same directory as the model checkpoint directory.
+
+
+    **Note**:
+
+    In this workflow, we will be leveraging a Llama2 (13B parameters) chat model, which requires 50 GB of GPU memory. If you prefer to leverage the 7B parameter model, it will require 38 GB of memory. The 70B parameter model initially requires 240 GB of memory.
+    IMPORTANT: For this initial version of the workflow, an A100 GPU is supported.
+
+
+## Install Guide
+### Step 1: Move to the deploy directory
+    cd deploy
+
+### Step 2: Set Environment Variables
+
+Modify ``compose.env`` in the ``deploy`` directory to set your environment variables. The following variables are required.
+
+    # full path to the local copy of the model weights
+    export MODEL_DIRECTORY="$HOME/src/Llama-2-13b-chat-hf"
+
+    # the architecture of the model. e.g. llama
+    export MODEL_ARCHITECTURE="llama"
+
+    # the name of the model being used - only for displaying on the frontend
+    export MODEL_NAME="llama-2-13b-chat"
+
+    # [OPTIONAL] the config file for the chain server
+    APP_CONFIG_FILE=/dev/null
+
+
+### Step 3: Build and Start Containers
+- Pull lfs files. This will pull large files from the repository.
+    ```
+    git lfs pull
+    ```
+- Run the following command to build containers.
+    ```
+    source compose.env; docker compose build
+    ```
+
+- Run the following command to start containers.
+    ```
+    source compose.env; docker compose up -d
+    ```
+    > ⚠️ **NOTE**: It will take a few minutes for the containers to come up, and it may take up to 5 minutes for the Triton server to be ready. Adding the `-d` flag will have the services run in the background. ⚠️
+
+- Run ``docker ps -a``. When the containers are ready, the output should look similar to the image below.
+    ![Docker Output](./images/docker-output.png "Docker Output Image")
+
+### Step 4: Experiment with RAG in JupyterLab
+
+This AI Workflow includes Jupyter notebooks which allow you to experiment with RAG.
+
+- Using a web browser, type in the following URL to open Jupyter:
+
+    ``http://host-ip:8888``
+
+- Locate the LLM Streaming Client notebook ``01-llm-streaming-client.ipynb``, which demonstrates how to stream responses from the LLM.
+
+- Proceed with the next 4 notebooks:
+
+    - [Document Question-Answering with LangChain](./notebooks/02_langchain_simple.ipynb)
+
+    - [Document Question-Answering with LlamaIndex](./notebooks/03_llama_index_simple.ipynb)
+
+    - [Advanced Document Question-Answering with LlamaIndex](./notebooks/04_llamaindex_hier_node_parser.ipynb)
+
+    - [Interact with the REST FastAPI Server](./notebooks/05_dataloader.ipynb)
+
+### Step 5: Run the Sample Web Application
+A sample chatbot web application is provided in the workflow. Requests to the chat system are wrapped in FastAPI calls.
+
+- Open the web application at ``http://host-ip:8090``.
+
+- Type in the following question without using a knowledge base: "How many cores are on the Nvidia Grace superchip?"
+
+    **Note:** the chatbot says the chip doesn't exist.
+
+- To use a knowledge base:
+
+    - Click the **Knowledge Base** tab and upload the file [dataset.zip](./notebooks/dataset.zip).
+
+- Return to the **Converse** tab and check **[X] Use knowledge base**.
+
+- Retype the question: "How many cores are on the Nvidia Grace superchip?"
+
+
+# Learn More
+1. [Architecture Guide](./docs/architecture.md): Detailed explanation of the different components and how they are tied together.
+2. Component Guides: Component-specific features are described in these sections.
+    1. [Chain Server](./docs/chat_server.md)
+    2. [NeMo Framework Inference Server](./docs/llm_inference_server.md)
+    3. [Jupyter Server](./docs/jupyter_server.md)
+    4. [Sample frontend](./docs/frontend.md)
+3. [Configuration Guide](./docs/configuration.md): This guide covers the different configurations available for this workflow.
+4.
[Support Matrix](./docs/support_matrix.md): This covers GPU, CPU, Memory and Storage requirements for deploying this workflow. + +# Known Issues +- Uploading a file with size more than 10 MB may fail due to preset timeouts during the ingestion process. diff --git a/RetrievalAugmentedGeneration/chain_server/__init__.py b/RetrievalAugmentedGeneration/chain_server/__init__.py new file mode 100644 index 000000000..513d86dd2 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A microservice for hosting Langchain Chains.""" diff --git a/RetrievalAugmentedGeneration/chain_server/chains.py b/RetrievalAugmentedGeneration/chain_server/chains.py new file mode 100644 index 000000000..e790fb545 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/chains.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""LLM Chains for executing Retrival Augmented Generation.""" +import base64 +from pathlib import Path +from typing import Generator + +from llama_index import Prompt, download_loader +from llama_index.node_parser import SimpleNodeParser +from llama_index.query_engine import RetrieverQueryEngine +from llama_index.response.schema import StreamingResponse + +from chain_server.utils import ( + LimitRetrievedNodesLength, + get_config, + get_doc_retriever, + get_llm, + get_text_splitter, + get_vector_index, + is_base64_encoded, + set_service_context, +) + + +def llm_chain( + context: str, question: str, num_tokens: int +) -> Generator[str, None, None]: + """Execute a simple LLM chain using the components defined above.""" + set_service_context() + prompt = get_config().prompts.chat_template.format( + context_str=context, query_str=question + ) + response = get_llm().stream_complete(prompt, tokens=num_tokens) + gen_response = (resp.delta for resp in response) + return gen_response + + +def rag_chain(prompt: str, num_tokens: int) -> Generator[str, None, None]: + """Execute a Retrieval Augmented Generation chain using the components defined above.""" + set_service_context() + get_llm().llm.tokens = num_tokens # type: ignore + retriever = get_doc_retriever(num_nodes=4) + qa_template = Prompt(get_config().prompts.rag_template) + query_engine = RetrieverQueryEngine.from_args( + retriever, + text_qa_template=qa_template, + node_postprocessors=[LimitRetrievedNodesLength()], + streaming=True, + ) + response = query_engine.query(prompt) + + # Properly handle an empty response + if isinstance(response, StreamingResponse): + return response.response_gen + return StreamingResponse(iter([])).response_gen # type: ignore + + +def ingest_docs(data_dir: str, filename: str) -> None: + """Ingest documents to the VectorDB.""" + unstruct_reader = download_loader("UnstructuredReader") + loader = unstruct_reader() + documents = loader.load_data(file=Path(data_dir), split_documents=False) + + encoded_filename = filename[:-4] + if not is_base64_encoded(encoded_filename): + encoded_filename = base64.b64encode(encoded_filename.encode("utf-8")).decode( + "utf-8" + ) + + for document in documents: + document.metadata = {"filename": encoded_filename} + + index = get_vector_index() + text_splitter = get_text_splitter() + node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter) + nodes = node_parser.get_nodes_from_documents(documents) + index.insert_nodes(nodes) diff --git a/RetrievalAugmentedGeneration/chain_server/configuration.py b/RetrievalAugmentedGeneration/chain_server/configuration.py new file mode 100644 index 000000000..ed35442c9 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/configuration.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The definition of the application configuration.""" +from chain_server.configuration_wizard import ConfigWizard, configclass, configfield + + +@configclass +class MilvusConfig(ConfigWizard): + """Configuration class for the Weaviate connection. + + :cvar url: URL of Milvus DB + """ + + url: str = configfield( + "url", + default="http://localhost:19530", + help_txt="The host of the machine running Milvus DB", + ) + + +@configclass +class TritonConfig(ConfigWizard): + """Configuration class for the Triton connection. + + :cvar server_url: The location of the Triton server hosting the llm model. + :cvar model_name: The name of the hosted model. + """ + + server_url: str = configfield( + "server_url", + default="localhost:8001", + help_txt="The location of the Triton server hosting the llm model.", + ) + model_name: str = configfield( + "model_name", + default="ensemble", + help_txt="The name of the hosted model.", + ) + + +@configclass +class TextSplitterConfig(ConfigWizard): + """Configuration class for the Text Splitter. + + :cvar chunk_size: Chunk size for text splitter. + :cvar chunk_overlap: Text overlap in text splitter. + """ + + chunk_size: int = configfield( + "chunk_size", + default=510, + help_txt="Chunk size for text splitting.", + ) + chunk_overlap: int = configfield( + "chunk_overlap", + default=200, + help_txt="Overlapping text length for splitting.", + ) + + +@configclass +class EmbeddingConfig(ConfigWizard): + """Configuration class for the Embeddings. + + :cvar model_name: The name of the huggingface embedding model. + """ + + model_name: str = configfield( + "model_name", + default="intfloat/e5-large-v2", + help_txt="The name of huggingface embedding model.", + ) + + +@configclass +class PromptsConfig(ConfigWizard): + """Configuration class for the Prompts. + + :cvar chat_template: Prompt template for chat. + :cvar rag_template: Prompt template for rag. + """ + + chat_template: str = configfield( + "chat_template", + default=( + "[INST] <>" + "You are a helpful, respectful and honest assistant." + "Always answer as helpfully as possible, while being safe." + "Please ensure that your responses are positive in nature." + "<>" + "[/INST] {context_str} [INST] {query_str} [/INST]" + ), + help_txt="Prompt template for chat.", + ) + rag_template: str = configfield( + "rag_template", + default=( + "[INST] <>" + "Use the following context to answer the user's question. If you don't know the answer," + "just say that you don't know, don't try to make up an answer." + "<>" + "[INST] Context: {context_str} Question: {query_str} Only return the helpful" + " answer below and nothing else. Helpful answer:[/INST]" + ), + help_txt="Prompt template for rag.", + ) + + +@configclass +class AppConfig(ConfigWizard): + """Configuration class for the application. + + :cvar milvus: The configuration of the Milvus vector db connection. + :type milvus: MilvusConfig + :cvar triton: The configuration of the backend Triton server. 
+ :type triton: TritonConfig + :cvar text_splitter: The configuration for text splitter + :type text_splitter: TextSplitterConfig + :cvar embeddings: The configuration for huggingface embeddings + :type embeddings: EmbeddingConfig + :cvar prompts: The Prompts template for RAG and Chat + :type prompts: PromptsConfig + """ + + milvus: MilvusConfig = configfield( + "milvus", + env=False, + help_txt="The configuration of the Milvus connection.", + default=MilvusConfig(), + ) + triton: TritonConfig = configfield( + "triton", + env=False, + help_txt="The configuration for the Triton server hosting the embedding models.", + default=TritonConfig(), + ) + text_splitter: TextSplitterConfig = configfield( + "text_splitter", + env=False, + help_txt="The configuration for text splitter.", + default=TextSplitterConfig(), + ) + embeddings: EmbeddingConfig = configfield( + "embeddings", + env=False, + help_txt="The configuration of embedding model.", + default=EmbeddingConfig(), + ) + prompts: PromptsConfig = configfield( + "prompts", + env=False, + help_txt="Prompt templates for chat and rag.", + default=PromptsConfig(), + ) diff --git a/RetrievalAugmentedGeneration/chain_server/configuration_wizard.py b/RetrievalAugmentedGeneration/chain_server/configuration_wizard.py new file mode 100644 index 000000000..d63d9e416 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/configuration_wizard.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module containing utilities for defining application configuration. + +This module provides a configuration wizard class that can read configuration data from YAML, JSON, and environment +variables. The configuration wizard is based heavily off of the JSON and YAML wizards from the `dataclass-wizard` +Python package. That package is in-turn based heavily off of the built-in `dataclass` module. + +This module adds Environment Variable parsing to config file reading. +""" +# pylint: disable=too-many-lines; this file is meant to be portable between projects so everything is put into one file + +import json +import logging +import os +from dataclasses import _MISSING_TYPE, dataclass +from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union + +import yaml +from dataclass_wizard import ( + JSONWizard, + LoadMeta, + YAMLWizard, + errors, + fromdict, + json_field, +) +from dataclass_wizard.models import JSONField +from dataclass_wizard.utils.string_conv import to_camel_case + +configclass = dataclass(frozen=True) +ENV_BASE = "APP" +_LOGGER = logging.getLogger(__name__) + + +def configfield( + name: str, *, env: bool = True, help_txt: str = "", **kwargs: Any +) -> JSONField: + """Create a data class field with the specified name in JSON format. + + :param name: The name of the field. + :type name: str + :param env: Whether this field should be configurable from an environment variable. 
+ :type env: bool + :param help_txt: The description of this field that is used in help docs. + :type help_txt: str + :param **kwargs: Optional keyword arguments to customize the JSON field. More information here: + https://dataclass-wizard.readthedocs.io/en/latest/dataclass_wizard.html#dataclass_wizard.json_field + :type **kwargs: Any + :returns: A JSONField instance with the specified name and optional parameters. + :rtype: JSONField + + :raises TypeError: If the provided name is not a string. + """ + # sanitize specified name + if not isinstance(name, str): + raise TypeError("Provided name must be a string.") + json_name = to_camel_case(name) + + # update metadata + meta = kwargs.get("metadata", {}) + meta["env"] = env + meta["help"] = help_txt + kwargs["metadata"] = meta + + # create the data class field + field = json_field(json_name, **kwargs) + return field + + +class _Color: + """A collection of colors used when writing output to the shell.""" + + # pylint: disable=too-few-public-methods; this class does not require methods. + + PURPLE = "\033[95m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + + +class ConfigWizard(JSONWizard, YAMLWizard): # type: ignore[misc] # dataclass-wizard doesn't provide stubs + """A configuration wizard class that can read configuration data from YAML, JSON, and environment variables.""" + + # pylint: disable=arguments-differ,arguments-renamed; this class intentionally reduces arguments for some methods. + + @classmethod + def print_help( + cls, + help_printer: Callable[[str], Any], + *, + env_parent: Optional[str] = None, + json_parent: Optional[Tuple[str, ...]] = None, + ) -> None: + """Print the help documentation for the application configuration with the provided `write` function. + + :param help_printer: The `write` function that will be used to output the data. + :param help_printer: Callable[[str], None] + :param env_parent: The name of the parent environment variable. Leave blank, used for recursion. + :type env_parent: Optional[str] + :param json_parent: The name of the parent JSON key. Leave blank, used for recursion. + :type json_parent: Optional[Tuple[str, ...]] + :returns: A list of tuples with one item per configuration value. Each item will have the environment variable + and a tuple to the path in configuration. + :rtype: List[Tuple[str, Tuple[str, ...]]] + """ + if not env_parent: + env_parent = "" + help_printer("---\n") + if not json_parent: + json_parent = () + + for ( + _, + val, + ) in ( + cls.__dataclass_fields__.items() # pylint: disable=no-member; false positive + ): # pylint: disable=no-member; member is added by dataclass. 
+ jsonname = val.json.keys[0] + envname = jsonname.upper() + full_envname = f"{ENV_BASE}{env_parent}_{envname}" + is_embedded_config = hasattr(val.type, "envvars") + + # print the help data + indent = len(json_parent) * 2 + if is_embedded_config: + default = "" + elif not isinstance(val.default_factory, _MISSING_TYPE): + default = val.default_factory() + elif isinstance(val.default, _MISSING_TYPE): + default = "NO-DEFAULT-VALUE" + else: + default = val.default + help_printer( + f"{_Color.BOLD}{' ' * indent}{jsonname}:{_Color.END} {default}\n" + ) + + # print comments + if is_embedded_config: + indent += 2 + if val.metadata.get("help"): + help_printer(f"{' ' * indent}# {val.metadata['help']}\n") + if not is_embedded_config: + typestr = getattr(val.type, "__name__", None) or str(val.type).replace( + "typing.", "" + ) + help_printer(f"{' ' * indent}# Type: {typestr}\n") + if val.metadata.get("env", True): + help_printer(f"{' ' * indent}# ENV Variable: {full_envname}\n") + # if not is_embedded_config: + help_printer("\n") + + if is_embedded_config: + new_env_parent = f"{env_parent}_{envname}" + new_json_parent = json_parent + (jsonname,) + val.type.print_help( + help_printer, env_parent=new_env_parent, json_parent=new_json_parent + ) + + help_printer("\n") + + @classmethod + def envvars( + cls, + env_parent: Optional[str] = None, + json_parent: Optional[Tuple[str, ...]] = None, + ) -> List[Tuple[str, Tuple[str, ...], type]]: + """Calculate valid environment variables and their config structure location. + + :param env_parent: The name of the parent environment variable. + :type env_parent: Optional[str] + :param json_parent: The name of the parent JSON key. + :type json_parent: Optional[Tuple[str, ...]] + :returns: A list of tuples with one item per configuration value. Each item will have the environment variable, + a tuple to the path in configuration, and they type of the value. + :rtype: List[Tuple[str, Tuple[str, ...], type]] + """ + if not env_parent: + env_parent = "" + if not json_parent: + json_parent = () + output = [] + + for ( + _, + val, + ) in ( + cls.__dataclass_fields__.items() # pylint: disable=no-member; false positive + ): # pylint: disable=no-member; member is added by dataclass. + jsonname = val.json.keys[0] + envname = jsonname.upper() + full_envname = f"{ENV_BASE}{env_parent}_{envname}" + is_embedded_config = hasattr(val.type, "envvars") + + # add entry to output list + if is_embedded_config: + new_env_parent = f"{env_parent}_{envname}" + new_json_parent = json_parent + (jsonname,) + output += val.type.envvars( + env_parent=new_env_parent, json_parent=new_json_parent + ) + elif val.metadata.get("env", True): + output += [(full_envname, json_parent + (jsonname,), val.type)] + + return output + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConfigWizard": + """Create a ConfigWizard instance from a dictionary. + + :param data: The dictionary containing the configuration data. + :type data: Dict[str, Any] + :returns: A ConfigWizard instance created from the input dictionary. + :rtype: ConfigWizard + + :raises RuntimeError: If the configuration data is not a dictionary. 
+ """ + # sanitize data + if not data: + data = {} + if not isinstance(data, dict): + raise RuntimeError("Configuration data is not a dictionary.") + + # parse env variables + for envvar in cls.envvars(): + var_name, conf_path, var_type = envvar + var_value = os.environ.get(var_name) + if var_value: + var_value = try_json_load(var_value) + update_dict(data, conf_path, var_value) + _LOGGER.debug( + "Found EnvVar Config - %s:%s = %s", + var_name, + str(var_type), + repr(var_value), + ) + + LoadMeta(key_transform="CAMEL").bind_to(cls) + return fromdict(cls, data) # type: ignore[no-any-return] # dataclass-wizard doesn't provide stubs + + @classmethod + def from_file(cls, filepath: str) -> Optional["ConfigWizard"]: + """Load the application configuration from the specified file. + + The file must be either in JSON or YAML format. + + :returns: The fully processed configuration file contents. If the file was unreadable, None will be returned. + :rtype: Optional["ConfigWizard"] + """ + # open the file + try: + # pylint: disable-next=consider-using-with; using a with would make exception handling even more ugly + file = open(filepath, encoding="utf-8") + except FileNotFoundError: + _LOGGER.error("The configuration file cannot be found.") + file = None + except PermissionError: + _LOGGER.error( + "Permission denied when trying to read the configuration file." + ) + file = None + if not file: + return None + + # read the file + try: + data = read_json_or_yaml(file) + except ValueError as err: + _LOGGER.error( + "Configuration file must be valid JSON or YAML. The following errors occured:\n%s", + str(err), + ) + data = None + config = None + finally: + file.close() + + # parse the file + if data: + try: + config = cls.from_dict(data) + except errors.MissingFields as err: + _LOGGER.error( + "Configuration is missing required fields: \n%s", str(err) + ) + config = None + except errors.ParseError as err: + _LOGGER.error("Invalid configuration value provided:\n%s", str(err)) + config = None + else: + config = cls.from_dict({}) + + return config + + +def read_json_or_yaml(stream: TextIO) -> Dict[str, Any]: + """Read a file without knowing if it is JSON or YAML formatted. + + The file will first be assumed to be JSON formatted. If this fails, an attempt to parse the file with the YAML + parser will be made. If both of these fail, an exception will be raised that contains the exception strings returned + by both the parsers. + + :param stream: An IO stream that allows seeking. + :type stream: typing.TextIO + :returns: The parsed file contents. + :rtype: typing.Dict[str, typing.Any]: + :raises ValueError: If the IO stream is not seekable or if the file doesn't appear to be JSON or YAML formatted. 
+ """ + exceptions: Dict[str, Union[None, ValueError, yaml.error.YAMLError]] = { + "JSON": None, + "YAML": None, + } + data: Dict[str, Any] + + # ensure we can rewind the file + if not stream.seekable(): + raise ValueError("The provided stream must be seekable.") + + # attempt to read json + try: + data = json.loads(stream.read()) + except ValueError as err: + exceptions["JSON"] = err + else: + return data + finally: + stream.seek(0) + + # attempt to read yaml + try: + data = yaml.safe_load(stream.read()) + except (yaml.error.YAMLError, ValueError) as err: + exceptions["YAML"] = err + else: + return data + + # neither json nor yaml + err_msg = "\n\n".join( + [key + " Parser Errors:\n" + str(val) for key, val in exceptions.items()] + ) + raise ValueError(err_msg) + + +def try_json_load(value: str) -> Any: + """Try parsing the value as JSON and silently ignore errors. + + :param value: The value on which a JSON load should be attempted. + :type value: str + :returns: Either the parsed JSON or the provided value. + :rtype: typing.Any + """ + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + +def update_dict( + data: Dict[str, Any], + path: Tuple[str, ...], + value: Any, + overwrite: bool = False, +) -> None: + """Update a dictionary with a new value at a given path. + + :param data: The dictionary to be updated. + :type data: Dict[str, Any] + :param path: The path to the key that should be updated. + :type path: Tuple[str, ...] + :param value: The new value to be set at the specified path. + :type value: Any + :param overwrite: If True, overwrite the existing value. Otherwise, don't update if the key already exists. + :type overwrite: bool + :returns: None + """ + end = len(path) + target = data + for idx, key in enumerate(path, 1): + # on the last field in path, update the dict if necessary + if idx == end: + if overwrite or not target.get(key): + target[key] = value + return + + # verify the next hop exists + if not target.get(key): + target[key] = {} + + # if the next hop is not a dict, exit + if not isinstance(target.get(key), dict): + return + + # get next hop + target = target.get(key) # type: ignore[assignment] # type has already been enforced. 
diff --git a/RetrievalAugmentedGeneration/chain_server/openapi_schema.json b/RetrievalAugmentedGeneration/chain_server/openapi_schema.json new file mode 100644 index 000000000..ce4779156 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/openapi_schema.json @@ -0,0 +1,244 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "FastAPI", + "version": "0.1.0" + }, + "paths": { + "/uploadDocument": { + "post": { + "summary": "Upload Document", + "description": "Upload a document to the vector store.", + "operationId": "upload_document_uploadDocument_post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_document_uploadDocument_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/generate": { + "post": { + "summary": "Generate Answer", + "description": "Generate and stream the response to the provided prompt.", + "operationId": "generate_answer_generate_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Prompt" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/documentSearch": { + "post": { + "summary": "Document Search", + "description": "Search for the most relevant documents for the given search parameters.", + "operationId": "document_search_documentSearch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DocumentSearch" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { + "type": "object" + }, + "type": "array", + "title": "Response Document Search Documentsearch Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Body_upload_document_uploadDocument_post": { + "properties": { + "file": { + "type": "string", + "format": "binary", + "title": "File" + } + }, + "type": "object", + "required": [ + "file" + ], + "title": "Body_upload_document_uploadDocument_post" + }, + "DocumentSearch": { + "properties": { + "content": { + "type": "string", + "title": "Content", + "description": "The content or keywords to search for within documents." + }, + "num_docs": { + "type": "integer", + "title": "Num Docs", + "description": "The maximum number of documents to return in the response.", + "default": 4 + } + }, + "type": "object", + "required": [ + "content" + ], + "title": "DocumentSearch", + "description": "Definition of the DocumentSearch API data type." 
+ }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "Prompt": { + "properties": { + "question": { + "type": "string", + "title": "Question", + "description": "The input query/prompt to the pipeline." + }, + "context": { + "type": "string", + "title": "Context", + "description": "Additional context for the question (optional)" + }, + "use_knowledge_base": { + "type": "boolean", + "title": "Use Knowledge Base", + "description": "Whether to use a knowledge base", + "default": true + }, + "num_tokens": { + "type": "integer", + "title": "Num Tokens", + "description": "The maximum number of tokens in the response.", + "default": 50 + } + }, + "type": "object", + "required": [ + "question", + "context" + ], + "title": "Prompt", + "description": "Definition of the Prompt API data type." + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "type": "array", + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + }, + "type": "object", + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError" + } + } + } + } \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/chain_server/server.py b/RetrievalAugmentedGeneration/chain_server/server.py new file mode 100644 index 000000000..ea7f3176e --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/server.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The definition of the Llama Index chain server.""" +import base64 +import os +import shutil +from pathlib import Path +from typing import Any, Dict, List + +from fastapi import FastAPI, File, UploadFile +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel, Field + +from chain_server import utils +from chain_server import chains + +# create the FastAPI server +app = FastAPI() +# prestage the embedding model +_ = utils.get_embedding_model() +# set the global service context for Llama Index +utils.set_service_context() + + +class Prompt(BaseModel): + """Definition of the Prompt API data type.""" + + question: str = Field(description="The input query/prompt to the pipeline.") + context: str = Field(description="Additional context for the question (optional)") + use_knowledge_base: bool = Field(description="Whether to use a knowledge base", default=True) + num_tokens: int = Field(description="The maximum number of tokens in the response.", default=50) + + +class DocumentSearch(BaseModel): + """Definition of the DocumentSearch API data type.""" + + content: str = Field(description="The content or keywords to search for within documents.") + num_docs: int = Field(description="The maximum number of documents to return in the response.", default=4) + + +@app.post("/uploadDocument") +async def upload_document(file: UploadFile = File(...)) -> JSONResponse: + """Upload a document to the vector store.""" + if not file.filename: + return JSONResponse(content={"message": "No files provided"}, status_code=200) + + upload_folder = "uploaded_files" + upload_file = os.path.basename(file.filename) + if not upload_file: + raise RuntimeError("Error parsing uploaded filename.") + file_path = os.path.join(upload_folder, upload_file) + uploads_dir = Path(upload_folder) + uploads_dir.mkdir(parents=True, exist_ok=True) + + with open(file_path, "wb") as f: + shutil.copyfileobj(file.file, f) + + chains.ingest_docs(file_path, upload_file) + + return JSONResponse( + content={"message": "File uploaded successfully"}, status_code=200 + ) + + +@app.post("/generate") +async def generate_answer(prompt: Prompt) -> StreamingResponse: + """Generate and stream the response to the provided prompt.""" + if prompt.use_knowledge_base: + generator = chains.rag_chain(prompt.question, prompt.num_tokens) + return StreamingResponse(generator, media_type="text/event-stream") + + generator = chains.llm_chain(prompt.context, prompt.question, prompt.num_tokens) + return StreamingResponse(generator, media_type="text/event-stream") + + +@app.post("/documentSearch") +def document_search(data: DocumentSearch) -> List[Dict[str, Any]]: + """Search for the most relevant documents for the given search parameters.""" + retriever = utils.get_doc_retriever(num_nodes=data.num_docs) + nodes = retriever.retrieve(data.content) + output = [] + for node in nodes: + file_name = nodes[0].metadata["filename"] + decoded_filename = base64.b64decode(file_name.encode("utf-8")).decode("utf-8") + entry = {"score": node.score, "source": decoded_filename, "content": node.text} + output.append(entry) + + return output diff --git a/RetrievalAugmentedGeneration/chain_server/trt_llm.py b/RetrievalAugmentedGeneration/chain_server/trt_llm.py new file mode 100644 index 000000000..b4f7746d9 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/trt_llm.py @@ -0,0 +1,557 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Langchain LLM component for connecting to Triton + TensorRT LLM backend.""" +# pylint: disable=too-many-lines +import abc +import json +import logging +import queue +import random +import time +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import google.protobuf.json_format +import numpy as np +import tritonclient.grpc as grpcclient +import tritonclient.http as httpclient +from tritonclient.grpc.service_pb2 import ModelInferResponse +from tritonclient.utils import np_to_triton_dtype + +try: + from langchain.callbacks.manager import CallbackManagerForLLMRun + from langchain.llms.base import LLM + from langchain.pydantic_v1 import Field, root_validator + + USE_LANGCHAIN = True +except ImportError: + USE_LANGCHAIN = False + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +STOP_WORDS = [""] +RANDOM_SEED = 0 + +if USE_LANGCHAIN: + # pylint: disable-next=too-few-public-methods # Interface is defined by LangChain + class TensorRTLLM(LLM): # LLM class not typed in langchain + """A custom Langchain LLM class that integrates with TRTLLM triton models. + + Arguments: + server_url: (str) The URL of the Triton inference server to use. + model_name: (str) The name of the Triton TRT model to use. + temperature: (str) Temperature to use for sampling + top_p: (float) The top-p value to use for sampling + top_k: (float) The top k values use for sampling + beam_width: (int) Last n number of tokens to penalize + repetition_penalty: (int) Last n number of tokens to penalize + length_penalty: (float) The penalty to apply repeated tokens + tokens: (int) The maximum number of tokens to generate. + client: The client object used to communicate with the inference server + """ + + server_url: str = Field(None, alias="server_url") + + # # all the optional arguments + model_name: str = "ensemble" + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 0 + top_k: Optional[int] = 1 + tokens: Optional[int] = 100 + beam_width: Optional[int] = 1 + repetition_penalty: Optional[float] = 1.0 + length_penalty: Optional[float] = 1.0 + client: Any + streaming: Optional[bool] = True + + @root_validator() # typing not declared in langchain + @classmethod + def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate that python package exists in environment.""" + try: + if values.get("streaming", True): + values["client"] = GrpcTritonClient(values["server_url"]) + else: + values["client"] = HttpTritonClient(values["server_url"]) + + except ImportError as err: + raise ImportError( + "Could not import triton client python package. " + "Please install it with `pip install tritonclient[all]`." 
+ ) from err + return values + + @property + def _get_model_default_parameters(self) -> Dict[str, Any]: + return { + "tokens": self.tokens, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "repetition_penalty": self.repetition_penalty, + "length_penalty": self.length_penalty, + "beam_width": self.beam_width, + } + + @property + def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]: + params = {**self._get_model_default_parameters, **kwargs} + return params + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get all the identifying parameters.""" + return { + "server_url": self.server_url, + "model_name": self.model_name, + } + + @property + def _llm_type(self) -> str: + return "triton_tensorrt" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, # pylint: disable=unused-argument + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """ + Execute an inference request. + + Args: + prompt: The prompt to pass into the model. + stop: A list of strings to stop generation when encountered + + Returns: + The string generated by the model + """ + text_callback = None + if run_manager: + text_callback = partial( + run_manager.on_llm_new_token, verbose=self.verbose + ) + + invocation_params = self._get_model_default_parameters + invocation_params.update(kwargs) + invocation_params["prompt"] = [[prompt]] + model_params = self._identifying_params + model_params.update(kwargs) + request_id = str(random.randint(1, 9999999)) # nosec + + self.client.load_model(model_params["model_name"]) + if isinstance(self.client, GrpcTritonClient): + return self._streaming_request( + model_params, request_id, invocation_params, text_callback + ) + return self._request(model_params, invocation_params, text_callback) + + def _streaming_request( + self, + model_params: Dict[str, Any], + request_id: str, + invocation_params: Dict[str, Any], + text_callback: Optional[Callable[[str], None]], + ) -> str: + """Request a streaming inference session.""" + result_queue = self.client.request_streaming( + model_params["model_name"], request_id, **invocation_params + ) + + response = "" + start_time = time.time() + tokens_generated = 0 + for token in result_queue: + if text_callback: + text_callback(token) + tokens_generated += 1 + response = response + token + total_time = time.time() - start_time + logger.info( + "\n--- Generated %s tokens in %s seconds ---", + tokens_generated, + total_time, + ) + logger.info("--- %s tokens/sec", tokens_generated / total_time) + return response + + def _request( + self, + model_params: Dict[str, Any], + invocation_params: Dict[str, Any], + text_callback: Optional[Callable[[str], None]], + ) -> str: + """Request a streaming inference session.""" + token: str = self.client.request( + model_params["model_name"], **invocation_params + ) + if text_callback: + text_callback(token) + return token + + +class StreamingResponseGenerator(queue.Queue[Optional[str]]): + """A Generator that provides the inference results from an LLM.""" + + def __init__( + self, client: "GrpcTritonClient", request_id: str, force_batch: bool + ) -> None: + """Instantiate the generator class.""" + super().__init__() + self._client = client + self.request_id = request_id + self._batch = force_batch + + def __iter__(self) -> "StreamingResponseGenerator": + """Return self as a generator.""" + return self + + def __next__(self) -> str: + """Return the next retrieved token.""" + val = self.get() + if val is None or val 
in STOP_WORDS: + self._stop_stream() + raise StopIteration() + return val + + def _stop_stream(self) -> None: + """Drain and shutdown the Triton stream.""" + self._client.stop_stream( + "tensorrt_llm", self.request_id, signal=not self._batch + ) + + +class _BaseTritonClient(abc.ABC): + """An abstraction of the connection to a triton inference server.""" + + def __init__(self, server_url: str) -> None: + """Initialize the client.""" + self._server_url = server_url + self._client = self._inference_server_client(server_url) + + @property + @abc.abstractmethod + def _inference_server_client( + self, + ) -> Union[ + Type[grpcclient.InferenceServerClient], Type[httpclient.InferenceServerClient] + ]: + """Return the prefered InferenceServerClient class.""" + + @property + @abc.abstractmethod + def _infer_input( + self, + ) -> Union[Type[grpcclient.InferInput], Type[httpclient.InferInput]]: + """Return the preferred InferInput.""" + + @property + @abc.abstractmethod + def _infer_output( + self, + ) -> Union[ + Type[grpcclient.InferRequestedOutput], Type[httpclient.InferRequestedOutput] + ]: + """Return the preferred InferRequestedOutput.""" + + def load_model(self, model_name: str, timeout: int = 1000) -> None: + """Load a model into the server.""" + if self._client.is_model_ready(model_name): + return + + self._client.load_model(model_name) + t0 = time.perf_counter() + t1 = t0 + while not self._client.is_model_ready(model_name) and t1 - t0 < timeout: + t1 = time.perf_counter() + + if not self._client.is_model_ready(model_name): + raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s") + + def get_model_list(self) -> List[str]: + """Get a list of models loaded in the triton server.""" + res = self._client.get_model_repository_index(as_json=True) + return [model["name"] for model in res["models"]] + + def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int: + """Get the modle concurrency.""" + self.load_model(model_name, timeout) + instances = self._client.get_model_config(model_name, as_json=True)["config"][ + "instance_group" + ] + return sum(instance["count"] * len(instance["gpus"]) for instance in instances) + + def _generate_stop_signals( + self, + ) -> List[Union[grpcclient.InferInput, httpclient.InferInput]]: + """Generate the signal to stop the stream.""" + inputs = [ + self._infer_input("input_ids", [1, 1], "INT32"), + self._infer_input("input_lengths", [1, 1], "INT32"), + self._infer_input("request_output_len", [1, 1], "UINT32"), + self._infer_input("stop", [1, 1], "BOOL"), + ] + inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32)) + inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32)) + inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32)) + inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool")) + return inputs + + def _generate_outputs( + self, + ) -> List[Union[grpcclient.InferRequestedOutput, httpclient.InferRequestedOutput]]: + """Generate the expected output structure.""" + return [self._infer_output("text_output")] + + def _prepare_tensor( + self, name: str, input_data: Any + ) -> Union[grpcclient.InferInput, httpclient.InferInput]: + """Prepare an input data structure.""" + t = self._infer_input( + name, input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + t.set_data_from_numpy(input_data) + return t + + def _generate_inputs( # pylint: disable=too-many-arguments,too-many-locals + self, + prompt: str, + tokens: int = 300, + temperature: float = 1.0, + top_k: float = 1, + top_p: float = 0, + 
beam_width: int = 1, + repetition_penalty: float = 1, + length_penalty: float = 1.0, + stream: bool = True, + ) -> List[Union[grpcclient.InferInput, httpclient.InferInput]]: + """Create the input for the triton inference server.""" + query = np.array(prompt).astype(object) + request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1)) + runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1)) + runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1)) + temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1)) + len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1)) + repetition_penalty_array = ( + np.array([repetition_penalty]).astype(np.float32).reshape((1, -1)) + ) + random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1)) + beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1)) + streaming_data = np.array([[stream]], dtype=bool) + + inputs = [ + self._prepare_tensor("text_input", query), + self._prepare_tensor("max_tokens", request_output_len), + self._prepare_tensor("top_k", runtime_top_k), + self._prepare_tensor("top_p", runtime_top_p), + self._prepare_tensor("temperature", temperature_array), + self._prepare_tensor("length_penalty", len_penalty), + self._prepare_tensor("repetition_penalty", repetition_penalty_array), + self._prepare_tensor("random_seed", random_seed), + self._prepare_tensor("beam_width", beam_width_array), + self._prepare_tensor("stream", streaming_data), + ] + return inputs + + def _trim_batch_response(self, result_str: str) -> str: + """Trim the resulting response from a batch request by removing provided prompt and extra generated text.""" + # extract the generated part of the prompt + split = result_str.split("[/INST]", 1) + generated = split[-1] + end_token = generated.find("") + if end_token == -1: + return generated + generated = generated[:end_token].strip() + return generated + + +class GrpcTritonClient(_BaseTritonClient): + """GRPC connection to a triton inference server.""" + + @property + def _inference_server_client( + self, + ) -> Type[grpcclient.InferenceServerClient]: + """Return the prefered InferenceServerClient class.""" + return grpcclient.InferenceServerClient # type: ignore + + @property + def _infer_input(self) -> Type[grpcclient.InferInput]: + """Return the preferred InferInput.""" + return grpcclient.InferInput # type: ignore + + @property + def _infer_output( + self, + ) -> Type[grpcclient.InferRequestedOutput]: + """Return the preferred InferRequestedOutput.""" + return grpcclient.InferRequestedOutput # type: ignore + + def _send_stop_signals(self, model_name: str, request_id: str) -> None: + """Send the stop signal to the Triton Inference server.""" + stop_inputs = self._generate_stop_signals() + self._client.async_stream_infer( + model_name, + stop_inputs, + request_id=request_id, + parameters={"Streaming": True}, + ) + + @staticmethod + def _process_result(result: Dict[str, str]) -> str: + """Post-process the result from the server.""" + message = ModelInferResponse() + generated_text: str = "" + google.protobuf.json_format.Parse(json.dumps(result), message) + infer_result = grpcclient.InferResult(message) + np_res = infer_result.as_numpy("text_output") + + generated_text = "" + if np_res is not None: + generated_text = "".join([token.decode() for token in np_res]) + + return generated_text + + def _stream_callback( + self, + result_queue: queue.Queue[Union[Optional[Dict[str, str]], str]], + force_batch: bool, + 
result: Any, + error: str, + ) -> None: + """Add streamed result to queue.""" + if error: + result_queue.put(error) + else: + response_raw = result.get_response(as_json=True) + if "outputs" in response_raw: + # the very last response might have no output, just the final flag + response = self._process_result(response_raw) + if force_batch: + response = self._trim_batch_response(response) + + if response in STOP_WORDS: + result_queue.put(None) + else: + result_queue.put(response) + + if response_raw["parameters"]["triton_final_response"]["bool_param"]: + # end of the generation + result_queue.put(None) + + # pylint: disable-next=too-many-arguments + def _send_prompt_streaming( + self, + model_name: str, + request_inputs: Any, + request_outputs: Optional[Any], + request_id: str, + result_queue: StreamingResponseGenerator, + force_batch: bool = False, + ) -> None: + """Send the prompt and start streaming the result.""" + self._client.start_stream( + callback=partial(self._stream_callback, result_queue, force_batch) + ) + self._client.async_stream_infer( + model_name=model_name, + inputs=request_inputs, + outputs=request_outputs, + request_id=request_id, + ) + + def request_streaming( + self, + model_name: str, + request_id: Optional[str] = None, + force_batch: bool = False, + **params: Any, + ) -> StreamingResponseGenerator: + """Request a streaming connection.""" + if not self._client.is_model_ready(model_name): + raise RuntimeError("Cannot request streaming, model is not loaded") + + if not request_id: + request_id = str(random.randint(1, 9999999)) # nosec + + result_queue = StreamingResponseGenerator(self, request_id, force_batch) + inputs = self._generate_inputs(stream=not force_batch, **params) + outputs = self._generate_outputs() + self._send_prompt_streaming( + model_name, + inputs, + outputs, + request_id, + result_queue, + force_batch, + ) + return result_queue + + def stop_stream( + self, model_name: str, request_id: str, signal: bool = True + ) -> None: + """Close the streaming connection.""" + if signal: + self._send_stop_signals(model_name, request_id) + self._client.stop_stream() + + +class HttpTritonClient(_BaseTritonClient): + """HTTP connection to a triton inference server.""" + + @property + def _inference_server_client( + self, + ) -> Type[httpclient.InferenceServerClient]: + """Return the prefered InferenceServerClient class.""" + return httpclient.InferenceServerClient # type: ignore + + @property + def _infer_input(self) -> Type[httpclient.InferInput]: + """Return the preferred InferInput.""" + return httpclient.InferInput # type: ignore + + @property + def _infer_output( + self, + ) -> Type[httpclient.InferRequestedOutput]: + """Return the preferred InferRequestedOutput.""" + return httpclient.InferRequestedOutput # type: ignore + + def request( + self, + model_name: str, + **params: Any, + ) -> str: + """Request inferencing from the triton server.""" + if not self._client.is_model_ready(model_name): + raise RuntimeError("Cannot request streaming, model is not loaded") + + # create model inputs and outputs + inputs = self._generate_inputs(stream=False, **params) + outputs = self._generate_outputs() + + # call the model for inference + result = self._client.infer(model_name, inputs=inputs, outputs=outputs) + result_str = "".join( + [val.decode("utf-8") for val in result.as_numpy("text_output").tolist()] + ) + + # extract the generated part of the prompt + # return(result_str) + return self._trim_batch_response(result_str) diff --git 
a/RetrievalAugmentedGeneration/chain_server/utils.py b/RetrievalAugmentedGeneration/chain_server/utils.py new file mode 100644 index 000000000..1629c6362 --- /dev/null +++ b/RetrievalAugmentedGeneration/chain_server/utils.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for the LLM Chains.""" +import os +import base64 +from functools import lru_cache +from typing import TYPE_CHECKING, List, Optional + +import torch +from llama_index.indices.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode +from llama_index.utils import globals_helper +from llama_index.vector_stores import MilvusVectorStore +from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context +from llama_index.llms import LangChainLLM +from llama_index import LangchainEmbedding +from langchain.text_splitter import SentenceTransformersTokenTextSplitter +from langchain.embeddings import HuggingFaceEmbeddings +from chain_server.trt_llm import TensorRTLLM +from chain_server import configuration + +if TYPE_CHECKING: + from llama_index.indices.base_retriever import BaseRetriever + from llama_index.indices.query.schema import QueryBundle + from llama_index.schema import NodeWithScore + from chain_server.configuration_wizard import ConfigWizard + +DEFAULT_MAX_CONTEXT = 1500 +DEFAULT_NUM_TOKENS = 50 + + +class LimitRetrievedNodesLength(BaseNodePostprocessor): + """Llama Index chain filter to limit token lengths.""" + + def postprocess_nodes( + self, nodes: List["NodeWithScore"], query_bundle: Optional["QueryBundle"] = None + ) -> List["NodeWithScore"]: + """Filter function.""" + included_nodes = [] + current_length = 0 + limit = DEFAULT_MAX_CONTEXT + + for node in nodes: + current_length += len( + globals_helper.tokenizer( + node.node.get_content(metadata_mode=MetadataMode.LLM) + ) + ) + if current_length > limit: + break + included_nodes.append(node) + + return included_nodes + + +@lru_cache +def set_service_context() -> None: + """Set the global service context.""" + service_context = ServiceContext.from_defaults( + llm=get_llm(), embed_model=get_embedding_model() + ) + set_global_service_context(service_context) + + +@lru_cache +def get_config() -> "ConfigWizard": + """Parse the application configuration.""" + config_file = os.environ.get("APP_CONFIG_FILE", "/dev/null") + config = configuration.AppConfig.from_file(config_file) + if config: + return config + raise RuntimeError("Unable to find configuration.") + + +@lru_cache +def get_vector_index() -> VectorStoreIndex: + """Create the vector db index.""" + config = get_config() + vector_store = MilvusVectorStore(uri=config.milvus.url, dim=1024, overwrite=False) + return VectorStoreIndex.from_vector_store(vector_store) + + +@lru_cache +def get_doc_retriever(num_nodes: int = 4) -> "BaseRetriever": + """Create the document retriever.""" + index = 
get_vector_index() + return index.as_retriever(similarity_top_k=num_nodes) + + +@lru_cache +def get_llm() -> LangChainLLM: + """Create the LLM connection.""" + settings = get_config() + trtllm = TensorRTLLM( # type: ignore + server_url=settings.triton.server_url, + model_name=settings.triton.model_name, + tokens=DEFAULT_NUM_TOKENS, + ) + return LangChainLLM(llm=trtllm) + + +@lru_cache +def get_embedding_model() -> LangchainEmbedding: + """Create the embedding model.""" + model_kwargs = {"device": "cpu"} + if torch.cuda.is_available(): + model_kwargs["device"] = "cuda:0" + + encode_kwargs = {"normalize_embeddings": False} + hf_embeddings = HuggingFaceEmbeddings( + model_name=get_config().embeddings.model_name, + model_kwargs=model_kwargs, + encode_kwargs=encode_kwargs, + ) + + # Load in a specific embedding model + return LangchainEmbedding(hf_embeddings) + + +@lru_cache +def is_base64_encoded(s: str) -> bool: + """Check if a string is base64 encoded.""" + try: + # Attempt to decode the string as base64 + decoded_bytes = base64.b64decode(s) + # Encode the decoded bytes back to a string to check if it's valid + decoded_str = decoded_bytes.decode("utf-8") + # If the original string and the decoded string match, it's base64 encoded + return s == base64.b64encode(decoded_str.encode("utf-8")).decode("utf-8") + except Exception: # pylint:disable = broad-exception-caught + # An exception occurred during decoding, so it's not base64 encoded + return False + + +def get_text_splitter() -> SentenceTransformersTokenTextSplitter: + """Return the token text splitter instance from langchain.""" + return SentenceTransformersTokenTextSplitter( + model_name=get_config().embeddings.model_name, + chunk_size=get_config().text_splitter.chunk_size, + chunk_overlap=get_config().text_splitter.chunk_overlap, + ) diff --git a/RetrievalAugmentedGeneration/deploy/compose.env b/RetrievalAugmentedGeneration/deploy/compose.env new file mode 100644 index 000000000..4a88c7ad8 --- /dev/null +++ b/RetrievalAugmentedGeneration/deploy/compose.env @@ -0,0 +1,23 @@ +# full path to the local copy of the model weights +export MODEL_DIRECTORY="/home/nvidia/llama2_13b_chat_hf_v1/" + +# the architecture of the model. eg: llama +export MODEL_ARCHITECTURE="llama" + +# the name of the model being used - only for displaying on frontend +export MODEL_NAME="Llama-2-13b-chat-hf" + +# [OPTIONAL] the maximum number of input tokens +# export MODEL_MAX_INPUT_LENGTH=3000 + +# [OPTIONAL] the maximum number of output tokens +# export MODEL_MAX_OUTPUT_LENGTH=512 + +# [OPTIONAL] the number of GPUs to make available to the inference server +# export INFERENCE_GPU_COUNT="all" + +# [OPTIONAL] the base directory inside which all persistent volumes will be created +# export DOCKER_VOLUME_DIRECTORY="." + +# [OPTIONAL] the config file for chain server w.r.t. pwd +export APP_CONFIG_FILE=/dev/null diff --git a/RetrievalAugmentedGeneration/deploy/config.yaml b/RetrievalAugmentedGeneration/deploy/config.yaml new file mode 100644 index 000000000..62b70beb1 --- /dev/null +++ b/RetrievalAugmentedGeneration/deploy/config.yaml @@ -0,0 +1,52 @@ +milvus: + # The configuration of the Milvus connection. + + url: "http://milvus:19530" + # The location of the Milvus Server. + # Type: str + # ENV Variable: APP_MILVUS_URL + +triton: + # The configuration for the Triton server hosting the embedding models. + + server_url: triton:8001 + # The location of the Triton server hosting the embedding model. 
+ # Type: str + # ENV Variable: APP_TRITON_SERVERURL + + model_name: ensemble + # The name of the hosted model. + # Type: str + # ENV Variable: APP_TRITON_MODELNAME + + +text_splitter: + # The configuration for the Text Splitter. + + chunk_size: 510 + # Chunk size for text splitting. + # Type: int + + chunk_overlap: 200 + # Overlapping text length for splitting. + # Type: int + +embeddings: + # The configuration embedding models. + + model_name: intfloat/e5-large-v2 + # The name embedding search model from huggingface. + # Type: str + +prompts: + # The configuration for the prompts used for response generation. + + chat_template: + [INST] <>You are a helpful, respectful and honest assistant.Always answer as helpfully as possible, while being safe.Please ensure that your responses are positive in nature.<>[/INST] {context_str} [INST] {query_str} [/INST] + # The chat prompt template guides the model to generate responses for queries. + # Type: str + + rag_template: + "[INST] <>Use the following context to answer the user's question. If you don't know the answer,just say that you don't know, don't try to make up an answer.<>[INST] Context: {context_str} Question: {query_str} Only return the helpful answer below and nothing else. Helpful answer:[/INST]" + # The RAG prompt template instructs the model to generate responses for queries while utilizing knowledge base. + # Type: str diff --git a/RetrievalAugmentedGeneration/deploy/docker-compose.yaml b/RetrievalAugmentedGeneration/deploy/docker-compose.yaml new file mode 100644 index 000000000..8c133ef3e --- /dev/null +++ b/RetrievalAugmentedGeneration/deploy/docker-compose.yaml @@ -0,0 +1,174 @@ +services: + + triton: + container_name: triton-inference-server + image: llm-inference-server:latest + build: + context: ../llm-inference-server/ + dockerfile: Dockerfile + volumes: + - ${MODEL_DIRECTORY:?please update the env file and source it before running}:/model + command: ${MODEL_ARCHITECTURE:?please update the env file and source it before running} --max-input-length ${MODEL_MAX_INPUT_LENGTH:-3000} --max-output-length ${MODEL_MAX_OUTPUT_LENGTH:-512} + ports: + - "8000:8000" + - "8001:8001" + - "8002:8002" + expose: + - "8000" + - "8001" + - "8002" + shm_size: 20gb + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: ${INFERENCE_GPU_COUNT:-all} + capabilities: [gpu] + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/v2/health/ready"] + interval: 30s + timeout: 20s + retries: 3 + start_period: 10m + + jupyter-server: + container_name: jupyter-notebook-server + image: notebook-server:latest + build: + context: ../ + dockerfile: Dockerfile.notebooks + ports: + - "8888:8888" + expose: + - "8888" + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + depends_on: + - "triton" + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.5 + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + - ETCD_QUOTA_BACKEND_BYTES=4294967296 + - ETCD_SNAPSHOT_COUNT=50000 + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + healthcheck: + test: ["CMD", "etcdctl", "endpoint", "health"] + interval: 30s + timeout: 20s + retries: 3 + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + 
ports: + - "9001:9001" + - "9000:9000" + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data + command: minio server /minio_data --console-address ":9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + + milvus: + container_name: milvus-standalone + image: milvusdb/milvus:v2.3.1-gpu + command: ["milvus", "run", "standalone"] + environment: + ETCD_ENDPOINTS: etcd:2379 + MINIO_ADDRESS: minio:9000 + KNOWHERE_GPU_MEM_POOL_SIZE: 2048:4096 + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + start_period: 90s + timeout: 20s + retries: 3 + ports: + - "19530:19530" + - "9091:9091" + depends_on: + - "etcd" + - "minio" + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: ["gpu"] + count: 1 + + query: + container_name: query-router + image: chain-server:latest + build: + context: ../ + dockerfile: Dockerfile + command: --port 8081 --host 0.0.0.0 + environment: + APP_MILVUS_URL: "http://milvus:19530" + APP_TRITON_SERVERURL: "triton:8001" + APP_TRITON_MODELNAME: ensemble + APP_CONFIG_FILE: ${APP_CONFIG_FILE} + volumes: + - ${APP_CONFIG_FILE}:${APP_CONFIG_FILE} + ports: + - "8081:8081" + expose: + - "8081" + shm_size: 5gb + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + # healthcheck: + # test: ["CMD", "curl", "-f", "http://localhost:8080/"] + # interval: 30s + # timeout: 20s + # retries: 3 + depends_on: + - "milvus" + - "triton" + + frontend: + container_name: llm-playground + image: llm-playground:latest + build: + context: ../frontend/ + dockerfile: Dockerfile + command: --port 8090 + environment: + APP_SERVERURL: http://query + APP_SERVERPORT: 8081 + APP_MODELNAME: ${MODEL_NAME:-${MODEL_ARCHITECTURE}} + ports: + - "8090:8090" + expose: + - "8090" + depends_on: + - query + +networks: + default: + name: nvidia-llm diff --git a/RetrievalAugmentedGeneration/docs/architecture.md b/RetrievalAugmentedGeneration/docs/architecture.md new file mode 100644 index 000000000..4fd6e55de --- /dev/null +++ b/RetrievalAugmentedGeneration/docs/architecture.md @@ -0,0 +1,89 @@ + +Overview +================================= + +Generative AI enables users to quickly generate new content based on a variety of inputs and is a powerful tool for streamlining the workflow of creatives, engineers, researchers, scientists, and more. The use cases and possibilities span all industries and individuals. Generative AI models can produce novel content like stories, emails, music, images, and videos. + +Here at NVIDIA, we like to utilize our own products to make our lives easier, so we have used generative AI to create an NVIDIA chatbot enhanced with retrieval augmented generation (RAG). This chatbot is designed to assist an NVIDIA employee with answering public relations related questions. The sample dataset includes the last two years of NVIDIA press releases and corporate blog posts. Our development and deployment of that chatbot is the guide to this reference generative AI workflow. + +Generative AI starts with foundational models trained on vast quantities of unlabeled data. **Large language models (LLMs)** are trained on an extensive range of textual data online. These LLMs can understand prompts and generate novel, human-like responses. 
Businesses can build applications to leverage this capability of LLMs; for example creative writing assistants for marketing, document summarization for legal teams, and code writing for software development. + +To create true business value from LLMs, these foundational models need to be tailored to your enterprise use case. In this workflow, we use [RAG](https://blog.langchain.dev/tutorial-chatgpt-over-your-data/) with [Llama2](https://github.com/facebookresearch/llama/), an open source model from Meta, to achieve this. Augmenting an existing AI foundational model provides an advanced starting point and a low-cost solution that enterprises can leverage to generate accurate and clear responses to their specific use case. + +This RAG-based reference chatbot workflow contains: + + - [NVIDIA NeMo framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/index.html) - part of NVIDIA AI Enterprise solution + - [NVIDIA TensorRT-LLM](https://developer.nvidia.com/tensorrt) - for low latency and high throughput inference for LLMs + - [LangChain](https://github.com/langchain-ai/langchain/) and [LlamaIndex](https://www.llamaindex.ai/) for combining language model components and easily constructing question-answering from a company's database + - [Sample Jupyter Notebooks](jupyter_server.md) and [chatbot web application/API calls](./frontend.md) so that you can test the chat system in an interactive manner + - [Milvus](https://milvus.io/docs/install_standalone-docker.md) - Generated embeddings are stored in a vector database. The vector DB used in this workflow is Milvus. Milvus is an open-source vector database capable of NVIDIA GPU accelerated vector searches. + - [e5-large-v2 model](https://huggingface.co/embaas/sentence-transformers-e5-large-v2) from huggingface to generate the embeddings. + - [Llama2](https://github.com/facebookresearch/llama/), an open source model from Meta, to formulate natural responses. + +This RAG chatbot workflow provides a reference for you to build your own enterprise AI solution with minimal effort. This AI workflow was designed to be deployed as a Developer experience using Docker Compose on an NVIDIA AI Enterprise-supported platform, which can be deployed on-prem or using a cloud service provider (CSP). Workflow components are used to deploy models and inference pipeline, integrated together with the additional components as indicated in the diagram below: + +![Diagram](./../images/image0.png) + +NVIDIA AI Components +====================== +This reference workflow uses a variety of NVIDIA AI components to customize and deploy the RAG-based chatbot example. + + - NVIDIA TensorRT-LLM + - NVIDIA NeMo Inference Server + +The following sections describe these NVIDIA AI components further. + +**NVIDIA TensorRT-LLM Optimization** + +A LLM can be optimized using TensorRT-LLM. NVIDIA NeMo uses TensorRT for LLMs (TensorRT-LLM), for deployment which accelerates and maximizes inference performance on the latest LLMs. +In this workflow, we will be leveraging a Llama 2 (13B parameters) chat model. We will convert the foundational model to TensorRT format using TensorRT-LLM for optimized inference. + +**NVIDIA NeMo Framework Inference Server** + +With NeMo Framework Inference Server, the optimized LLM can be deployed for high-performance, cost-effective, and low-latency inference. NVIDIA NGC is used as model storage in this workflow, but you are free to choose different model storage solutions like MLFlow or AWS SageMaker. 
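+
+Once the inference server is up, you can verify from Python that it is ready to serve the optimized model. The sketch below is illustrative only: it uses the `tritonclient` gRPC API (the same client library used by the chain server), and the `triton:8001` endpoint and `ensemble` model name are the defaults from this workflow's Docker Compose and config files, so treat them as assumptions and adjust them for your deployment.
+
+```python
+# Minimal readiness probe against the Triton server started by `docker compose up triton`.
+# Assumes the workflow defaults: gRPC endpoint "triton:8001" and model name "ensemble".
+import tritonclient.grpc as grpcclient
+
+client = grpcclient.InferenceServerClient(url="triton:8001")
+
+# server-level liveness and readiness checks
+print("server live: ", client.is_server_live())
+print("server ready:", client.is_server_ready())
+
+# model-level readiness check for the TRT-LLM ensemble
+print("model ready: ", client.is_model_ready("ensemble"))
+```
+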
+The Triton Inference Server uses models stored in a model repository, available locally to serve inference requests. Once they are available in Triton, inference requests are sent from a client application. Python and C++ libraries provide APIs to simplify communication. Clients send HTTP/REST requests directly to Triton using HTTP/REST or gRPC protocols. + +Within this workflow, the Llama2 LLM was optimized using NVIDIA TensorRT for LLMs (TRT-LLM) which accelerates and maximizes inference performance on the latest LLMs. + +Inference Pipeline +==================== +To get started with the inferencing pipeline, we will first connect the customized LLM to a sample proprietary data source. This knowledge can come in many forms: product specifications, HR documents, or finance spreadsheets. Enhancing the model’s capabilities with this knowledge can be done with RAG. + +Since foundational LLMs are not trained on your proprietary enterprise data and are only trained up to a fixed point in time, they need to be augmented with additional data. RAG consists of two processes. First, *retrieval* of data from document repositories, databases, or APIs that are all outside of the foundational model’s knowledge. Second, is the *generation* of responses via Inference. The example used within this workflow is a corporate communications co-pilot that could either ingest source data from storage or by scraping. The following graphic describes an overview of this inference pipeline: + +![Diagram](./../images/image1.png) + +**Document Ingestion and Retrieval** + +RAG begins with a knowledge base of relevant up-to-date information. Since data within an enterprise is frequently updated, the ingestion of documents into a knowledge base should be a recurring process and scheduled as a job. Next, content from the knowledge base is passed to an embedding model (e5-large-v2, in the case of this workflow), which converts the content to vectors (referred to as “embeddings”). Generating embeddings is a critical step in RAG; it allows for the dense numerical representations of textual information. These embeddings are stored in a vector database, in this case Milvus, which is [RAFT accelerated](https://developer.nvidia.com/blog/accelerating-vector-search-using-gpu-powered-indexes-with-rapids-raft). + +**User Query and Response Generation** + +When a user query is sent to the inference server, it is converted to an embedding using the embedding model. This is the same embedding model used to convert the documents in the knowledge base (e5-large-v2, in the case of this workflow). The database performs a similarity/semantic search to find the vectors that most closely resemble the user’s intent and provides them to the LLM as enhanced context. Since Milvus is RAFT accelerated, the similarity serach is optimized on the GPU. Lastly, the LLM is used to generate a full answer that’s streamed to the user. This is all done with ease via [LangChain](https://github.com/langchain-ai/langchain/) and [LlamaIndex](https://www.llamaindex.ai) + +The following diagram illustrates the ingestion of documents and generation of responses. + +![Diagram](./../images/image2.png) + +LangChain allows you to write LLM wrappers for your own custom LLMs, so we have provided a sample wrapper for streaming responses from a TensorRT-LLM Llama 2 model running on Triton Inference Server. 
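+
+As a rough usage sketch (not the exact code used by the chain server), the wrapper can be instantiated and called like any other LangChain LLM. The server URL and model name below are the defaults from `deploy/config.yaml`, and the prompt is just a placeholder.
+
+```python
+# Hypothetical usage of the sample LangChain wrapper in chain_server/trt_llm.py.
+# Assumes Triton is serving the TRT-LLM "ensemble" model at triton:8001 (see deploy/config.yaml).
+from chain_server.trt_llm import TensorRTLLM
+
+llm = TensorRTLLM(
+    server_url="triton:8001",  # gRPC endpoint of the Triton inference server
+    model_name="ensemble",     # name of the hosted TRT-LLM model
+    tokens=100,                # maximum number of tokens to generate
+)
+
+# LangChain's standard LLM call interface; tokens are streamed from Triton
+# under the hood and returned here as a single concatenated string.
+print(llm("What is retrieval augmented generation?"))
+```
+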
This wrapper allows us to leverage LangChain’s standard interface for interacting with LLMs while still achieving vast performance speedup from TensorRT-LLM and scalable and flexible inference from Triton Inference Server. + +A sample chatbot web application is provided in the workflow so that you can test the chat system in an interactive manner. Requests to the chat system are wrapped in API calls, so these can be abstracted to other applications. + +An additional method of customization in the AI Workflow inference pipeline is via a prompt template. A prompt template is a pre-defined recipe for generating prompts for language models. They may contain instructions, few-shot examples, and context appropriate for a given task. In our example, we prompt our model to generate safe and polite responses. + + +**Triton Model Server** + +The Triton Inference Server uses models stored in a model repository, available locally to serve inference requests. Once they are available in Triton, inference requests are sent from a client application. Python and C++ libraries provide APIs to simplify communication. Clients send HTTP/REST requests directly to Triton using HTTP/REST or gRPC protocols. + +Within this workflow, the Llama2 LLM was optimized using NVIDIA TensorRT for LLMs (TRT-LLM) which accelerates and maximizes inference performance on the latest LLMs. + +**Vector DB** + +Milvus is an open-source vector database built to power embedding similarity search and AI applications. It makes unstructured data from API calls, PDFs, and other documents more accessible by storing them as embeddings. +When content from the knowledge base is passed to an embedding model (e5-large-v2), it converts the content to vectors (referred to as “embeddings”). These embeddings are stored in a vector database. The vector DB used in this workflow is Milvus. Milvus is an open-source vector database capable of NVIDIA GPU accelerated vector searches. + +*Note::* +``` +If needed, see Milvus's [documentation](https://milvus.io/docs/install_standalone-docker.md/) for how a Docker Compose file can be configured for Milvus. +``` \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/docs/chat_server.md b/RetrievalAugmentedGeneration/docs/chat_server.md new file mode 100644 index 000000000..ffdedbd26 --- /dev/null +++ b/RetrievalAugmentedGeneration/docs/chat_server.md @@ -0,0 +1,182 @@ + +# Chat Server +A sample fastapi based server is provided in the workflow so that you can test the chat system in an interactive manner. +This server wraps calls made to different components and orchestrates the entire flow. + +This API endpoint allows for several actions: +- [Chat Server](#chat-server) + - [Upload File Endpoint](#upload-file-endpoint) + - [Answer Generation Endpoint](#answer-generation-endpoint) + - [Document Search Endpoint](#document-search-endpoint) +- [Running the chain server](#running-the-chain-server) + +The API server swagger schema can be visualized at ``host-ip:8081/docs``. +You can checkout the openapi standard compatible schema for the endpoints supported [here](../chain_server/openapi_schema.json). + +The following sections describe the API endpoint actions further with relevant examples. + +### Upload File Endpoint +**Summary:** Upload a file. This endpoint should accept a post request with the following JSON in the body: + +```json +{ + "file": (file_path, file_binary_data, mime_type), +} +``` + +The response should be in JSON form. 
It should be a dictionary with a confirmation message: + +```json +{"message": "File uploaded successfully"} +``` + +**Endpoint:** ``/uploadDocument`` + +**HTTP Method:** POST + +**Request:** + +- **Content-Type:** multipart/form-data +- **Schema:** ``Body_upload_file_uploadDocument_post`` +- **Required:** Yes + +**Request Body Parameters:** +- ``file`` (Type: File) - The file to be uploaded. + +**Responses:** +- **200 - Successful Response** + + - Description: The file was successfully uploaded. + - Response Body: Empty + +- **422 - Validation Error** + + - Description: There was a validation error with the request. + - Response Body: Details of the validation error. + + + +### Answer Generation Endpoint +**Summary:** Generate an answer to a question. This endpoint should accept a post request with the following JSON content in the body: + +```json +{ + "question": "USER PROMPT", // A string of the prompt provided by the user + "context": "Conversation context to provide to the model.", + "use_knowledge_base": false, // A boolean flag to toggle VectorDB lookups + "num_tokens": 500, // The maximum number of tokens expected in the response. +} +``` + +The response should in JSON form. It should simply be a string of the response. + +```json +"LLM response" +``` + +The chat server must also handle responses being retrieved in chunks as opposed to all at once. The client code for response streaming looks like this: + +```python +with requests.post(url, stream=True, json=data, timeout=10) as req: + for chunk in req.iter_content(16): + yield chunk.decode("UTF-8") +``` + +**Endpoint:** ``/generate`` + +**HTTP Method:** POST + +**Operation ID:** ``generate_answer_generate_post`` + +**Request:** + +- **Content-Type:** application/json +- **Schema:** ``Prompt`` +- **Required:** Yes + +**Request Body Parameters:** + +- ``question`` (Type: string) - The question you want to ask. +- ``context`` (Type: string) - Additional context for the question (optional). +- ``use_knowledge_base`` (Type: boolean, Default: true) - Whether to use a knowledge base. +- ``num_tokens`` (Type: integer, Default: 500) - The maximum number of tokens in the response. + +**Responses:** + +- **200 - Successful Response** + + - Description: The answer was successfully generated. + - Response Body: An object containing the generated answer. + +- **422 - Validation Error** + + - Description: There was a validation error with the request. + - Response Body: Details of the validation error. + +### Document Search Endpoint +**Summary:** Search for documents based on content. This endpoint should accept a post request with the following JSON content in the body: + +```json +{ + "content": "USER PROMPT", // A string of the prompt provided by the user + "num_docs": "4", // An integer indicating how many documents should be returned +} +``` + +The response should in JSON form. It should be a list of dictionaries containing the document score and content. + +```json +[ + { + "score": 0.89123, + "content": "The content of the relevant chunks from the vector db.", + }, + ... +] +``` + + +**Endpoint:** ``/documentSearch`` +**HTTP Method:** POST + +**Operation ID:** ``document_search_documentSearch_post`` + +**Request:** + +- **Content-Type:** application/json +- **Schema:** ``DocumentSearch`` +- **Required:** Yes + +**Request Body Parameters:** + +- ``content`` (Type: string) - The content or keywords to search for within documents. +- ``num_docs`` (Type: integer, Default: 4) - The maximum number of documents to return in the response. 
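+
+For illustration, the endpoint can be exercised with a short client snippet such as the one below; the ``host-ip`` placeholder and port 8081 follow the chain server defaults used throughout this guide, and the query text is arbitrary.
+
+```python
+# Hypothetical client call for the /documentSearch endpoint of the chain server.
+# Replace "host-ip" with the address of the machine running the chain server.
+import requests
+
+payload = {"content": "What is NVIDIA Triton?", "num_docs": 4}
+response = requests.post("http://host-ip:8081/documentSearch", json=payload, timeout=30)
+response.raise_for_status()
+
+# each entry contains the similarity score and the matching chunk of text
+for doc in response.json():
+    print(doc["score"], doc["content"][:80])
+```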
+ +**Responses:** + +- **200 - Successful Response** + + - Description: Documents matching the search criteria were found. + - Response Body: An object containing the search results. + +- **422 - Validation Error** + + - Description: There was a validation error with the request. + - Response Body: Details of the validation error. + + +# Running the chain server +If the web frontend needs to be stood up manually for development purposes, run the following commands: + +- Build the web UI container from source +``` + cd deploy/ + source compose.env + docker compose build query +``` +- Run the container which will start the server +``` + docker compose up query +``` + +- Open the swagger URL at ``http://host-ip:8081`` to try out the exposed endpoints. \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/docs/configuration.md b/RetrievalAugmentedGeneration/docs/configuration.md new file mode 100644 index 000000000..1361e7c72 --- /dev/null +++ b/RetrievalAugmentedGeneration/docs/configuration.md @@ -0,0 +1,74 @@ +## Configuration Guide + +### Chain Server Configuration + +In this section, we explore the configurations for the [Chain Server](./chat_server.md). Chain server interaction with other components can be controlled by config. Chain Server interacts with components such as the `milvus` vector store and `triton` server, which hosts the Large Language Model (LLM). Additionally, we'll delve into customization options to fine-tune the behavior of the query server. These options include settings for the embedding model, chunk size, and prompts for generating responses. + +You can refer to [sample config](../deploy/config.yaml) to see the structure. + +#### Milvus Configuration +`Milvus` serves as a vector database for storing embeddings. + + url: Configure the HTTP URI where the Milvus server is hosted. + +#### Triton Configuration +LLM Inference server hosts the Large Language Model (LLM) with triton backend. + + server_url: Specify the url of the LLM Inference Server. + + model_name: Provide the name of the model hosted on the Triton server. + Note: Changing the value of this field may need code changes. + +#### Text Splitter Configuration +This section covers the settings for the Text Splitter component. + + chunk_size: Define the size at which text should be split before being converted into embeddings. + + chunk_overlap: Specify the overlap between two consecutive text chunks to prevent loss of context. + +#### Embeddings Configuration +The Embeddings section contains information required for generating embeddings. + + model_name: Indicate the name of the model used to generate embeddings. + Note: Note that this may also necessitate changes in the model's dimensions, which can be adjusted in the chain_server/utils.py file. + +#### Prompts Configuration +Customize prompts used for generating responses. + + chat_template: The chat prompt template guides the model to generate responses for queries. + rag_template: The RAG prompt Template instructs the model to generate responses while leveraging a knowledge base. + +You set path to use this config file to be used by chain server using enviornment variable `APP_CONFIG_FILE`. You can do the same in [compose.env](../deploy/compose.env) and source the file. + +### Configuring docker compose file +In this section, we will look into the environment variables and parameters that can be configured within the [Docker Compose](../deploy/docker-compose.yaml) YAML file. 
Our system comprises multiple microservices that interact harmoniously to generate responses. These microservices include LLM Inference Server, Jupyter Server, Milvus, Query/chain server, and Frontend. + +#### Triton Configuration +The LLM Inference Server is used for hosting the Large Language Model (LLM) with triton backend. You can configure the model information using the [compose.env](../deploy/compose.env) file or by setting the corresponding environment variables. Here is a list of environment variables utilized by the llm inference server: + + MODEL_DIRECTORY: Specifies the path to the model directory where model checkpoints are stored. + MODEL_ARCHITECTURE: Defines the architecture of the model used for deployment. + MODEL_MAX_INPUT_LENGTH: Maximum allowed input length, with a default value of 3000. + MODEL_MAX_OUTPUT_LENGTH: Maximum allowed output lenght, with a default of 512. + INFERENCE_GPU_COUNT: Specifies the GPUs to be used by Triton for model deployment, with the default setting being "all." + +#### Jupyter Server +This server hosts jupyter lab server. This contains notebook explaining the flow of chain server. + +#### Milvus +Milvus serves as a GPU-accelerated vector store database, where we store embeddings generated by the knowledge base. + +#### Query/Chain Server +The Query service is the core component responsible for interacting with the llm inference server and the Milvus server to obtain responses. The environment variables utilized by this container are described as follows: + + APP_MILVUS_URL: Specifies the URL where the Milvus server is hosted. + APP_TRITON_SERVERURL: Specifies the URL where the Triton server is hosted. + APP_TRITON_MODELNAME: The model name used by the Triton server. + APP_CONFIG_FILE: Provides the path to the configuration file used by the Chain Server or this container. Defaults to /dev/null + +#### Frontend +The Frontend component is the UI server that interacts with the Query/Chain Server to retrieve responses and provide UI interface to ingest documents. The following environment variables are used by the frontend: + + APP_SERVERURL: Indicates the URL where the Query/Chain Server is hosted. + APP_SERVERPORT: Specifies the port on which the Query/Chain Server operates. + APP_MODELNAME: Neme of the Large Language Model utilized for deployment. This information is for display purposes only and does not affect the inference process. \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/docs/frontend.md b/RetrievalAugmentedGeneration/docs/frontend.md new file mode 100644 index 000000000..4697155d9 --- /dev/null +++ b/RetrievalAugmentedGeneration/docs/frontend.md @@ -0,0 +1,33 @@ +# Web Frontend +------------ +The web frontend provides a UI on top of the [RAG chat server APIs](./chat_server.md). +- Users can chat with the LLM and see responses streamed back. +- By selecting “Use knowledge base,” the chatbot returns responses augmented with the data that’s been stored in the vector database. +- To store content in the vector database, change the window to “Knowledge Base” in the upper right corner and upload documents. + +![Diagram](./../images/image4.jpg) + +# Frontend structure + +At its core, llm-playground is a FastAPI server written in Python. This FastAPI server hosts two [Gradio](https://www.gradio.app/) applications, one for conversing with the model and another for uploading documents. These Gradio pages are wrapped in a static frame created with the Kaizen UI React+Next.js framework and compiled down to static pages. 
Iframes are used to mount the Gradio applications into the outer frame. In case you are interested in customizing this sample frontend follow [this section to understand the basic workflow.](../frontend/README.md) + +# Running the web UI +If the web frontend needs to be stood up manually for development purposes, run the following commands: + +- Build the web UI container from source +``` + cd deploy/ + source compose.env + docker compose build frontend +``` +- Run the container which will start the server +``` + docker compose up frontend +``` + +- Open the web application at ``http://host-ip:8090`` + +Note: +- If multiple pdf files are being uploaded the expected time of completion as shown in the UI may not be correct. + + diff --git a/RetrievalAugmentedGeneration/docs/jupyter_server.md b/RetrievalAugmentedGeneration/docs/jupyter_server.md new file mode 100644 index 000000000..a676e7067 --- /dev/null +++ b/RetrievalAugmentedGeneration/docs/jupyter_server.md @@ -0,0 +1,40 @@ +# Jupyter Notebooks +For development and experimentation purposes, the Jupyter notebooks provide guidance to building knowledge augmented chatbots. + +The following Jupyter notebooks are provided with the AI workflow: + +1. [**LLM Streaming Client**](../notebooks/01-llm-streaming-client.ipynb) + +This notebook demonstrates how to use a client to stream responses from an LLM deployed to NVIDIA Triton Inference Server with NVIDIA TensorRT-LLM (TRT-LLM). This deployment format optimizes the model for low latency and high throughput inference. + +2. [**Document Question-Answering with LangChain**](../notebooks/02_langchain_simple.ipynb) + +This notebook demonstrates how to use LangChain to build a chatbot that references a custom knowledge-base. LangChain provides a simple framework for connecting LLMs to your own data sources. It shows how to integrate a TensorRT-LLM to LangChain using a custom wrapper. + +3. [**Document Question-Answering with LlamaIndex**](../notebooks/03_llama_index_simple.ipynb) + +This notebook demonstrates how to use LlamaIndex to build a chatbot that references a custom knowledge-base. It contains the same functionality as this notebook before, but uses some LlamaIndex components instead of LangChain components. It also shows how the two frameworks can be used together. + +4. [**Advanced Document Question-Answering with LlamaIndex**](../notebooks/04_llamaindex_hier_node_parser.ipynb) + +This notebook demonstrates how to use LlamaIndex to build a more complex retrieval for a chatbot. The retrieval method shown in this notebook works well for code documentation; it retrieves more contiguous document blocks that preserve both code snippets and explanations of code. + +5. [**Interact with REST FastAPI Server**](../notebooks/05_dataloader.ipynb) + +This notebook demonstrates how to use the REST FastAPI server to upload the knowledge base and then ask a question without and with the knowledge base. + +# Running the notebooks +If a JupyterLab server needs to be compiled and stood up manually for development purposes, run the following commands: +- Build the container +``` + cd deploy/ + source compose.env + docker compose build jupyter-server +``` +- Run the container which starts the notebook server +``` + docker compose up jupyter-server +``` +- Using a web browser, type in the following URL to access the notebooks. 
+
+  ``http://host-ip:8888``
\ No newline at end of file
diff --git a/RetrievalAugmentedGeneration/docs/llm_inference_server.md b/RetrievalAugmentedGeneration/docs/llm_inference_server.md
new file mode 100644
index 000000000..0e60251d0
--- /dev/null
+++ b/RetrievalAugmentedGeneration/docs/llm_inference_server.md
@@ -0,0 +1,31 @@
+# NeMo Framework Inference Server
+
+We use the [NeMo Framework Inference Server](https://docs.nvidia.com/nemo-framework/user-guide/latest/deployingthenemoframeworkmodel.html) container, which helps us create an optimized LLM with TensorRT-LLM and deploy it with NVIDIA Triton Server for high-performance, cost-effective, and low-latency inference. Within this workflow, we use Llama2 models; the LLM Inference Server container contains the modules and scripts required for TRT-LLM conversion of the Llama2 models and for deployment with NVIDIA Triton Server.
+
+
+# Running the LLM Inference Server
+
+To convert Llama2 to TensorRT and host it on the Triton Model Server for development purposes, run the following commands:
+
+- Download the Llama2 Chat model weights from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or [HuggingFace](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf/). You can check the [support matrix](../docs/support_matrix.md) for the GPU requirements of the deployment.
+
+- Update [compose.env](../deploy/compose.env) with MODEL_DIRECTORY set to the path of the downloaded Llama2 model, and adjust other model parameters as needed.
+
+- Build the LLM inference server container from source
+```
+  cd deploy/
+  source compose.env
+  docker compose build triton
+```
+- Run the container, which will start the Triton server with the TRT-LLM optimized Llama2 model
+```
+  docker compose up triton
+```
+
+- Once the optimized Llama2 model is deployed in the Triton Server, clients can send HTTP/REST or gRPC requests directly to it. An example implementation of the client can be found [here](../llm-inference-server/model_server_client/trt_llm.py).
+
+
+
+**Note for checkpoints downloaded from Meta**:
+
+  When downloading model weights from Meta, you can follow the instructions up to the point of downloading the models using ``download.sh``. Meta will download two additional files, namely tokenizer.model and tokenizer_checklist.chk, outside of the model checkpoint directory. Ensure that you copy these files into the same directory as the model checkpoint directory.
\ No newline at end of file
diff --git a/RetrievalAugmentedGeneration/docs/support_matrix.md b/RetrievalAugmentedGeneration/docs/support_matrix.md
new file mode 100644
index 000000000..96bc81dca
--- /dev/null
+++ b/RetrievalAugmentedGeneration/docs/support_matrix.md
@@ -0,0 +1,30 @@
+# GPU Requirements
+Large Language Models are a heavily GPU-bound workload. All LLMs are sized by the number of billions of parameters that make up their networks. For this workflow, we are focusing on the Llama 2 Chat models from Meta. These models come in three different sizes: 7B, 13B, and 70B. All three models perform very well, but the 13B model is a good balance of performance and GPU memory utilization.
+
+Llama2-7B-Chat requires about 30GB of GPU memory.
+
+Llama2-13B-Chat requires about 50GB of GPU memory.
+
+Llama2-70B-Chat requires about 320GB of GPU memory.
+
+These resources can be provided by multiple GPUs on the same machine.
+
+To perform retrieval augmentation, another model must be hosted. This model is much smaller and is called an embedding model.
+
+
+**Note for checkpoint downloaded using Meta**:
+
+  When downloading model weights from Meta, you can follow the instructions up to the point of downloading the models using ``download.sh``. The download script will fetch two additional files, namely tokenizer.model and tokenizer_checklist.chk, outside of the model checkpoint directory. Ensure that you copy these files into the same directory as the model checkpoint directory.
\ No newline at end of file
diff --git a/RetrievalAugmentedGeneration/docs/support_matrix.md b/RetrievalAugmentedGeneration/docs/support_matrix.md
new file mode 100644
index 000000000..96bc81dca
--- /dev/null
+++ b/RetrievalAugmentedGeneration/docs/support_matrix.md
@@ -0,0 +1,30 @@
+# GPU Requirements
+Large Language Models are a heavily GPU-bound workload. LLMs are commonly characterized by the number of parameters (in billions) that make up their networks. For this workflow, we are focusing on the Llama 2 Chat models from Meta. These models come in three different sizes: 7B, 13B, and 70B. All three models perform very well, but the 13B model is a good balance of performance and GPU memory utilization.
+
+Llama2-7B-Chat requires about 30GB of GPU memory.
+
+Llama2-13B-Chat requires about 50GB of GPU memory.
+
+Llama2-70B-Chat requires about 320GB of GPU memory.
+
+These resources can be provided by multiple GPUs on the same machine.
+
+To perform retrieval augmentation, another model must be hosted. This model is much smaller and is called an embedding model. It is responsible for converting a sequence of words into a numerical vector representation. This model requires an additional 2GB of GPU memory.
+
+In this workflow, Milvus was selected as the vector database because it has implemented the NVIDIA RAFT libraries that enable GPU acceleration of vector searches. For the Milvus database, allow an additional 4GB of GPU memory. As an example, running the 13B model therefore needs roughly 50GB + 2GB + 4GB = 56GB of GPU memory in total.
+
+# CPU and Memory Requirements
+For development purposes, we recommend that at least 10 CPU cores and 64 GB of RAM are available.
+
+# Storage Requirements
+There are two main drivers of storage consumption in retrieval augmented generation: the model weights and the documents in the vector database. The file size of the model depends on how large the model is.
+
+Llama2-7B-Chat requires about 30GB of storage.
+
+Llama2-13B-Chat requires about 50GB of storage.
+
+Llama2-70B-Chat requires about 150GB of storage.
+
+The file space needed for the vector database depends on how many documents it will store. For development purposes, allocating 10 GB is plenty.
+
+You will additionally need about 60GB of storage for Docker images.
diff --git a/RetrievalAugmentedGeneration/frontend/Dockerfile b/RetrievalAugmentedGeneration/frontend/Dockerfile
new file mode 100644
index 000000000..be5975b67
--- /dev/null
+++ b/RetrievalAugmentedGeneration/frontend/Dockerfile
@@ -0,0 +1,12 @@
+FROM docker.io/library/python:3.11-slim
+
+COPY frontend /app/frontend
+COPY requirements.txt /app
+RUN apt-get update; \
+    apt-get upgrade -y; \
+    python3 -m pip --no-cache-dir install -r /app/requirements.txt; \
+    apt-get clean
+USER 1001
+
+WORKDIR /app
+ENTRYPOINT ["python3", "-m", "frontend"]
diff --git a/RetrievalAugmentedGeneration/frontend/README.md b/RetrievalAugmentedGeneration/frontend/README.md
new file mode 100644
index 000000000..9576536fd
--- /dev/null
+++ b/RetrievalAugmentedGeneration/frontend/README.md
@@ -0,0 +1,38 @@
+# Customizing the frontend
+
+This section outlines the steps needed to update the static HTML and JavaScript based pages.
+
+The Kaizen UI frame is stored in the folder [frontend_js/](frontend_js). To make modifications to this content, you will need an up-to-date version of NodeJS and NPM. You can find installation instructions for those components in the [official NodeJS documentation](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm).
+
+> **NOTE:** If you are using Ubuntu, do not use NodeJS from the Ubuntu repositories. It is too old to work with NextJS.
+
+Once you have a working install of NodeJS, open a shell in the [frontend_js/](frontend_js) directory and run the following command:
+
+```bash
+npm install
+```
+
+This will install NextJS and all of the necessary Kaizen UI and React components into the [frontend_js/node_modules/](frontend_js/node_modules) folder.
+
+The source code for the outer frame is in the [frontend_js/src/](frontend_js/src) directory. You can modify this code and run a development server to see your changes with the following command:
+
+```bash
+npm run dev
+```
+
+> **NOTE:** The development server will not be able to run and mount the Gradio applications. Instead, you will see the outer frame mounted in itself with a 404 message. This is normal.
+
+When your changes are complete, you can compile the NextJS code into static HTML and JavaScript with the following command:
+
+```bash
+npm run build
+```
+
+This will place the compiled static code in the [frontend_js/out/](frontend_js/out) directory.
This must be copied in the the FastAPI static directory. However, be sure to remove the existing static content in FastAPI. Open a shell in this projects root folder and run the following commands: + +```bash +rm -rf frontend/static/* +cp -rav frontend_js/out/* frontend/static/ +``` + +The next time you launch the FastAPI server, it will now have your updated outer frame. \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/frontend/frontend/__init__.py b/RetrievalAugmentedGeneration/frontend/frontend/__init__.py new file mode 100644 index 000000000..631a8847f --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/__init__.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Document Retrieval Service. + +Handle document ingestion and retrieval from a VectorDB. +""" + +import logging +import os +import sys +import typing + +if typing.TYPE_CHECKING: + from frontend.api import APIServer + + +_LOG_FMT = f"[{os.getpid()}] %(asctime)15s [%(levelname)7s] - %(name)s - %(message)s" +_LOG_DATE_FMT = "%b %d %H:%M:%S" +_LOGGER = logging.getLogger(__name__) + + +def bootstrap_logging(verbosity: int = 0) -> None: + """Configure Python's logger according to the given verbosity level. + + :param verbosity: The desired verbosity level. Must be one of 0, 1, or 2. + :type verbosity: typing.Literal[0, 1, 2] + """ + # determine log level + verbosity = min(2, max(0, verbosity)) # limit verbosity to 0-2 + log_level = [logging.WARN, logging.INFO, logging.DEBUG][verbosity] + + # configure python's logger + logging.basicConfig(format=_LOG_FMT, datefmt=_LOG_DATE_FMT, level=log_level) + # update existing loggers + _LOGGER.setLevel(log_level) + for logger in [ + __name__, + "uvicorn", + "uvicorn.access", + "uvicorn.error", + ]: + for handler in logging.getLogger(logger).handlers: + handler.setFormatter(logging.Formatter(fmt=_LOG_FMT, datefmt=_LOG_DATE_FMT)) + + +def main() -> "APIServer": + """Bootstrap and Execute the application. + + :returns: 0 if the application completed successfully, 1 if an error occurred. + :rtype: Literal[0,1] + """ + # boostrap python loggers + verbosity = int(os.environ.get("APP_VERBOSITY", "1")) + bootstrap_logging(verbosity) + + # load the application libraries + # pylint: disable=import-outside-toplevel; this is intentional to allow for the environment to be configured before + # any of the application libraries are loaded. 
+ from frontend import api, chat_client, configuration + + # load config + config_file = os.environ.get("APP_CONFIG_FILE", "/dev/null") + _LOGGER.info("Loading application configuration.") + config = configuration.AppConfig.from_file(config_file) + if not config: + sys.exit(1) + _LOGGER.info("Configuration: \n%s", config.to_yaml()) + + # connect to other services + client = chat_client.ChatClient( + f"{config.server_url}:{config.server_port}", config.model_name + ) + + # create api server + _LOGGER.info("Instantiating the API Server.") + server = api.APIServer(client) + server.configure_routes() + + # run until complete + _LOGGER.info("Starting the API Server.") + return server diff --git a/RetrievalAugmentedGeneration/frontend/frontend/__main__.py b/RetrievalAugmentedGeneration/frontend/frontend/__main__.py new file mode 100644 index 000000000..cd2fa34a3 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/__main__.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Entrypoint for the Conversation GUI. + +The functions in this module are responsible for bootstrapping then executing the Conversation GUI server. +""" + +import argparse +import os +import sys + +import uvicorn + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments for the program. + + :returns: A namespace containing the parsed arguments. + :rtype: argparse.Namespace + """ + parser = argparse.ArgumentParser(description="Document Retrieval Service") + + parser.add_argument( + "--help-config", + action="store_true", + default=False, + help="show the configuration help text", + ) + + parser.add_argument( + "-c", + "--config", + metavar="CONFIGURATION_FILE", + default="/dev/null", + help="path to the configuration file (json or yaml)", + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + default=1, + help="increase output verbosity", + ) + parser.add_argument( + "-q", + "--quiet", + action="count", + default=0, + help="decrease output verbosity", + ) + + parser.add_argument( + "--host", + metavar="HOSTNAME", + type=str, + default="0.0.0.0", # nosec # this is intentional + help="Bind socket to this host.", + ) + parser.add_argument( + "--port", + metavar="PORT_NUM", + type=int, + default=8080, + help="Bind socket to this port.", + ) + parser.add_argument( + "--workers", + metavar="NUM_WORKERS", + type=int, + default=1, + help="Number of worker processes.", + ) + parser.add_argument( + "--ssl-keyfile", metavar="SSL_KEY", type=str, default=None, help="SSL key file" + ) + parser.add_argument( + "--ssl-certfile", + metavar="SSL_CERT", + type=str, + default=None, + help="SSL certificate file", + ) + + cliargs = parser.parse_args() + if cliargs.help_config: + # pylint: disable=import-outside-toplevel; this is intentional to allow for the environment to be configured + # before any of the application libraries are loaded. 
+ from frontend.configuration import AppConfig + + sys.stdout.write("\nconfiguration file format:\n") + AppConfig.print_help(sys.stdout.write) + sys.exit(0) + + return cliargs + + +if __name__ == "__main__": + args = parse_args() + os.environ["APP_VERBOSITY"] = f"{args.verbose - args.quiet}" + os.environ["APP_CONFIG_FILE"] = args.config + uvicorn.run( + "frontend:main", + factory=True, + host=args.host, + port=args.port, + workers=args.workers, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/RetrievalAugmentedGeneration/frontend/frontend/api.py b/RetrievalAugmentedGeneration/frontend/frontend/api.py new file mode 100644 index 000000000..25df0ad51 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/api.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the Server that will host the frontend and API.""" +import os + +import gradio as gr +from fastapi import FastAPI +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles +from frontend.chat_client import ChatClient + +from frontend import pages + +STATIC_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "static") + + +class APIServer(FastAPI): + """A class that hosts the service api. + + :cvar title: The title of the server. + :type title: str + :cvar desc: A description of the server. + :type desc: str + """ + + title = "Chat" + desc = "This service provides a sample conversation frontend flow." + + def __init__(self, client: ChatClient) -> None: + """Initialize the API server.""" + self._client = client + super().__init__(title=self.title, description=self.desc) + + def configure_routes(self) -> None: + """Configure the routes in the API Server.""" + _ = gr.mount_gradio_app( + self, + blocks=pages.converse.build_page(self._client), + path=f"/content{pages.converse.PATH}", + ) + _ = gr.mount_gradio_app( + self, + blocks=pages.kb.build_page(self._client), + path=f"/content{pages.kb.PATH}", + ) + + @self.get("/") + async def root_redirect() -> FileResponse: + return FileResponse(os.path.join(STATIC_DIR, "converse.html")) + + @self.get("/converse") + async def converse_redirect() -> FileResponse: + return FileResponse(os.path.join(STATIC_DIR, "converse.html")) + + @self.get("/kb") + async def kb_redirect() -> FileResponse: + return FileResponse(os.path.join(STATIC_DIR, "kb.html")) + + self.mount("/", StaticFiles(directory=STATIC_DIR, html=True)) diff --git a/RetrievalAugmentedGeneration/frontend/frontend/assets/__init__.py b/RetrievalAugmentedGeneration/frontend/frontend/assets/__init__.py new file mode 100644 index 000000000..5e26052c3 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/assets/__init__.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains theming assets.""" +import os.path +from typing import Tuple + +import gradio as gr + +_ASSET_DIR = os.path.dirname(__file__) + + +def load_theme(name: str) -> Tuple[gr.Theme, str]: + """Load a pre-defined frontend theme. + + :param name: The name of the theme to load. + :type name: str + :returns: A tuple containing the Gradio theme and custom CSS. + :rtype: Tuple[gr.Theme, str] + """ + theme_json_path = os.path.join(_ASSET_DIR, f"{name}-theme.json") + theme_css_path = os.path.join(_ASSET_DIR, f"{name}-theme.css") + return ( + gr.themes.Default().load(theme_json_path), + open(theme_css_path, encoding="UTF-8").read(), + ) diff --git a/RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.css b/RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.css new file mode 100644 index 000000000..04e930498 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.css @@ -0,0 +1,13 @@ +.tabitem { + background-color: var(--block-background-fill); + } + + .gradio-container { + /* This needs to be !important, otherwise the breakpoint override the container being full width */ + max-width: 100% !important; + padding: 10px !important; + } + + footer { + visibility: hidden; + } diff --git a/RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.json b/RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.json new file mode 100644 index 000000000..a3218660b --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/assets/kaizen-theme.json @@ -0,0 +1,336 @@ +{ + "theme": { + "_font": [ + { + "__gradio_font__": true, + "name": "NVIDIA Sans", + "class": "font" + }, + { + "__gradio_font__": true, + "name": "ui-sans-serif", + "class": "font" + }, + { + "__gradio_font__": true, + "name": "system-ui", + "class": "font" + }, + { + "__gradio_font__": true, + "name": "sans-serif", + "class": "font" + } + ], + "_font_mono": [ + { + "__gradio_font__": true, + "name": "JetBrains Mono", + "class": "google" + }, + { + "__gradio_font__": true, + "name": "ui-monospace", + "class": "font" + }, + { + "__gradio_font__": true, + "name": "Consolas", + "class": "font" + }, + { + "__gradio_font__": true, + "name": "monospace", + "class": "font" + } + ], + "_stylesheets": [ + "https://fonts.googleapis.com/css2?family=JetBrains+Mono&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Lt.woff2", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_LtIt.woff2", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Rg.woff2", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_It.woff2", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Md.woff2", + 
"https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_MdIt.woff2", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Bd.woff2", + "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_BdIt.woff2" + ], + "background_fill_primary": "#ffffff", + "background_fill_primary_dark": "#292929", + "background_fill_secondary": "*neutral_50", + "background_fill_secondary_dark": "*neutral_900", + "block_background_fill": "#ffffff", + "block_background_fill_dark": "#292929", + "block_border_color": "#d8d8d8", + "block_border_color_dark": "*border_color_primary", + "block_border_width": "1px", + "block_info_text_color": "*body_text_color_subdued", + "block_info_text_color_dark": "*body_text_color_subdued", + "block_info_text_size": "*text_sm", + "block_info_text_weight": "400", + "block_label_background_fill": "#e4fabe", + "block_label_background_fill_dark": "#e4fabe", + "block_label_border_color": "#e4fabe", + "block_label_border_color_dark": "#e4fabe", + "block_label_border_width": "1px", + "block_label_margin": "0", + "block_label_padding": "*spacing_sm *spacing_lg", + "block_label_radius": "calc(*radius_lg - 1px) 0 calc(*radius_lg - 1px) 0", + "block_label_right_radius": "0 calc(*radius_lg - 1px) 0 calc(*radius_lg - 1px)", + "block_label_shadow": "*block_shadow", + "block_label_text_color": "#4d6721", + "block_label_text_color_dark": "#4d6721", + "block_label_text_size": "*text_sm", + "block_label_text_weight": "400", + "block_padding": "*spacing_xl calc(*spacing_xl + 2px)", + "block_radius": "*radius_lg", + "block_shadow": "*shadow_drop", + "block_title_background_fill": "none", + "block_title_border_color": "none", + "block_title_border_width": "0px", + "block_title_padding": "0", + "block_title_radius": "none", + "block_title_text_color": "*neutral_500", + "block_title_text_color_dark": "*neutral_200", + "block_title_text_size": "*text_md", + "block_title_text_weight": "500", + "body_background_fill": "#f2f2f2", + "body_background_fill_dark": "#202020", + "body_text_color": "#202020", + "body_text_color_dark": "#f2f2f2", + "body_text_color_subdued": "*neutral_400", + "body_text_color_subdued_dark": "*neutral_400", + "body_text_size": "*text_md", + "body_text_weight": "400", + "border_color_accent": "*primary_300", + "border_color_accent_dark": "*neutral_600", + "border_color_primary": "#d8d8d8", + "border_color_primary_dark": "#343434", + "button_border_width": "1px", + "button_border_width_dark": "1px", + "button_cancel_background_fill": "#dc3528", + "button_cancel_background_fill_dark": "#dc3528", + "button_cancel_background_fill_hover": "#b6251b", + "button_cancel_background_fill_hover_dark": "#b6251b", + "button_cancel_border_color": "#dc3528", + "button_cancel_border_color_dark": "#dc3528", + "button_cancel_border_color_hover": "#b6251b", + "button_cancel_border_color_hover_dark": "#b6251b", + "button_cancel_text_color": "#ffffff", + "button_cancel_text_color_dark": "#ffffff", + "button_cancel_text_color_hover": "#ffffff", + "button_cancel_text_color_hover_dark": "#ffffff", + "button_large_padding": "*spacing_lg calc(2 * *spacing_lg)", + "button_large_radius": "*radius_lg", + "button_large_text_size": "*text_lg", + "button_large_text_weight": "500", + "button_primary_background_fill": "#76b900", + "button_primary_background_fill_dark": "#76b900", + "button_primary_background_fill_hover": "#659f00", + "button_primary_background_fill_hover_dark": "#659f00", + "button_primary_border_color": "#76b900", + 
"button_primary_border_color_dark": "#76b900", + "button_primary_border_color_hover": "#659f00", + "button_primary_border_color_hover_dark": "#659f00", + "button_primary_text_color": "#202020", + "button_primary_text_color_dark": "#202020", + "button_primary_text_color_hover": "#202020", + "button_primary_text_color_hover_dark": "#202020", + "button_secondary_background_fill": "#ffffff", + "button_secondary_background_fill_dark": "#292929", + "button_secondary_background_fill_hover": "#e2e2e2", + "button_secondary_background_fill_hover_dark": "#202020", + "button_secondary_border_color": "#5e5e5e", + "button_secondary_border_color_dark": "#c6c6c6", + "button_secondary_border_color_hover": "#5e5e5e", + "button_secondary_border_color_hover_dark": "#c6c6c6", + "button_secondary_text_color": "#5e5e5e", + "button_secondary_text_color_dark": "#e2e2e2", + "button_secondary_text_color_hover": "#343434", + "button_secondary_text_color_hover_dark": "#ffffff", + "button_shadow": "*shadow_drop", + "button_shadow_active": "*shadow_inset", + "button_shadow_hover": "*shadow_drop_lg", + "button_small_padding": "*spacing_sm calc(2 * *spacing_sm)", + "button_small_radius": "*radius_lg", + "button_small_text_size": "*text_md", + "button_small_text_weight": "400", + "button_transition": "none", + "chatbot_code_background_color": "*neutral_100", + "chatbot_code_background_color_dark": "*neutral_800", + "checkbox_background_color": "*background_fill_primary", + "checkbox_background_color_dark": "*neutral_800", + "checkbox_background_color_focus": "*checkbox_background_color", + "checkbox_background_color_focus_dark": "*checkbox_background_color", + "checkbox_background_color_hover": "*checkbox_background_color", + "checkbox_background_color_hover_dark": "*checkbox_background_color", + "checkbox_background_color_selected": "#659f00", + "checkbox_background_color_selected_dark": "#659f00", + "checkbox_border_color": "*neutral_300", + "checkbox_border_color_dark": "*neutral_700", + "checkbox_border_color_focus": "*secondary_500", + "checkbox_border_color_focus_dark": "*secondary_500", + "checkbox_border_color_hover": "*neutral_300", + "checkbox_border_color_hover_dark": "*neutral_600", + "checkbox_border_color_selected": "#659f00", + "checkbox_border_color_selected_dark": "#659f00", + "checkbox_border_radius": "*radius_sm", + "checkbox_border_width": "2px", + "checkbox_border_width_dark": "*input_border_width", + "checkbox_check": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e\")", + "checkbox_label_background_fill": "#ffffff", + "checkbox_label_background_fill_dark": "#292929", + "checkbox_label_background_fill_hover": "#ffffff", + "checkbox_label_background_fill_hover_dark": "#292929", + "checkbox_label_background_fill_selected": "*checkbox_label_background_fill", + "checkbox_label_background_fill_selected_dark": "*checkbox_label_background_fill", + "checkbox_label_border_color": "#ffffff", + "checkbox_label_border_color_dark": "#292929", + "checkbox_label_border_color_hover": "*checkbox_label_border_color", + "checkbox_label_border_color_hover_dark": "*checkbox_label_border_color", + "checkbox_label_border_width": "0", + "checkbox_label_border_width_dark": "*input_border_width", + "checkbox_label_gap": "16px", + "checkbox_label_padding": "", + "checkbox_label_shadow": "none", + "checkbox_label_text_color": 
"*body_text_color", + "checkbox_label_text_color_dark": "*body_text_color", + "checkbox_label_text_color_selected": "*checkbox_label_text_color", + "checkbox_label_text_color_selected_dark": "*checkbox_label_text_color", + "checkbox_label_text_size": "*text_md", + "checkbox_label_text_weight": "400", + "checkbox_shadow": "*input_shadow", + "color_accent": "*primary_500", + "color_accent_soft": "*primary_50", + "color_accent_soft_dark": "*neutral_700", + "container_radius": "*radius_lg", + "embed_radius": "*radius_lg", + "error_background_fill": "#fef2f2", + "error_background_fill_dark": "*neutral_900", + "error_border_color": "#fee2e2", + "error_border_color_dark": "#ef4444", + "error_border_width": "1px", + "error_icon_color": "#b91c1c", + "error_icon_color_dark": "#ef4444", + "error_text_color": "#b91c1c", + "error_text_color_dark": "#fef2f2", + "font": "'NVIDIA Sans', 'ui-sans-serif', 'system-ui', sans-serif", + "font_mono": "'JetBrains Mono', 'ui-monospace', 'Consolas', monospace", + "form_gap_width": "1px", + "input_background_fill": "white", + "input_background_fill_dark": "*neutral_800", + "input_background_fill_focus": "*secondary_500", + "input_background_fill_focus_dark": "*secondary_600", + "input_background_fill_hover": "*input_background_fill", + "input_background_fill_hover_dark": "*input_background_fill", + "input_border_color": "#d8d8d8", + "input_border_color_dark": "#343434", + "input_border_color_focus": "*secondary_300", + "input_border_color_focus_dark": "*neutral_700", + "input_border_color_hover": "*input_border_color", + "input_border_color_hover_dark": "*input_border_color", + "input_border_width": "2px", + "input_padding": "*spacing_xl", + "input_placeholder_color": "*neutral_400", + "input_placeholder_color_dark": "*neutral_500", + "input_radius": "*radius_lg", + "input_shadow": "0 0 0 *shadow_spread transparent, *shadow_inset", + "input_shadow_focus": "0 0 0 *shadow_spread *secondary_50, *shadow_inset", + "input_shadow_focus_dark": "0 0 0 *shadow_spread *neutral_700, *shadow_inset", + "input_text_size": "*text_md", + "input_text_weight": "400", + "layout_gap": "*spacing_xxl", + "link_text_color": "*secondary_600", + "link_text_color_active": "*secondary_600", + "link_text_color_active_dark": "*secondary_500", + "link_text_color_dark": "*secondary_500", + "link_text_color_hover": "*secondary_700", + "link_text_color_hover_dark": "*secondary_400", + "link_text_color_visited": "*secondary_500", + "link_text_color_visited_dark": "*secondary_600", + "loader_color": "*color_accent", + "name": "default", + "neutral_100": "#e2e2e2", + "neutral_200": "#d8d8d8", + "neutral_300": "#c6c6c6", + "neutral_400": "#8f8f8f", + "neutral_50": "#f2f2f2", + "neutral_500": "#767676", + "neutral_600": "#5e5e5e", + "neutral_700": "#343434", + "neutral_800": "#292929", + "neutral_900": "#202020", + "neutral_950": "#121212", + "panel_background_fill": "*background_fill_secondary", + "panel_background_fill_dark": "*background_fill_secondary", + "panel_border_color": "*border_color_primary", + "panel_border_color_dark": "*border_color_primary", + "panel_border_width": "0", + "primary_100": "#caf087", + "primary_200": "#b6e95d", + "primary_300": "#9fd73d", + "primary_400": "#76b900", + "primary_50": "#e4fabe", + "primary_500": "#659f00", + "primary_600": "#538300", + "primary_700": "#4d6721", + "primary_800": "#253a00", + "primary_900": "#1d2e00", + "primary_950": "#172400", + "prose_header_text_weight": "600", + "prose_text_size": "*text_md", + "prose_text_weight": "400", + "radio_circle": 
"url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e\")", + "radius_lg": "0px", + "radius_md": "0px", + "radius_sm": "0px", + "radius_xl": "0px", + "radius_xs": "0px", + "radius_xxl": "0px", + "radius_xxs": "0px", + "secondary_100": "#cde6fa", + "secondary_200": "#badef8", + "secondary_300": "#9accf2", + "secondary_400": "#3a96d9", + "secondary_50": "#e9f4fb", + "secondary_500": "#2378ca", + "secondary_600": "#2a63ba", + "secondary_700": "#013076", + "secondary_800": "#00265e", + "secondary_900": "#001e4b", + "secondary_950": "#00112c", + "section_header_text_size": "*text_md", + "section_header_text_weight": "500", + "shadow_drop": "none", + "shadow_drop_lg": "0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)", + "shadow_inset": "rgba(0,0,0,0.05) 0px 2px 4px 0px inset", + "shadow_spread": "3px", + "shadow_spread_dark": "1px", + "slider_color": "#9fd73d", + "spacing_lg": "8px", + "spacing_md": "6px", + "spacing_sm": "4px", + "spacing_xl": "10px", + "spacing_xs": "2px", + "spacing_xxl": "16px", + "spacing_xxs": "1px", + "stat_background_fill": "linear-gradient(to right, *primary_400, *primary_200)", + "stat_background_fill_dark": "linear-gradient(to right, *primary_400, *primary_600)", + "table_border_color": "*neutral_300", + "table_border_color_dark": "*neutral_700", + "table_even_background_fill": "white", + "table_even_background_fill_dark": "*neutral_950", + "table_odd_background_fill": "*neutral_50", + "table_odd_background_fill_dark": "*neutral_900", + "table_radius": "*radius_lg", + "table_row_focus": "*color_accent_soft", + "table_row_focus_dark": "*color_accent_soft", + "text_lg": "16px", + "text_md": "14px", + "text_sm": "12px", + "text_xl": "22px", + "text_xs": "10px", + "text_xxl": "26px", + "text_xxs": "9px" + } + } diff --git a/RetrievalAugmentedGeneration/frontend/frontend/chat_client.py b/RetrievalAugmentedGeneration/frontend/frontend/chat_client.py new file mode 100644 index 000000000..e2c5a938f --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/chat_client.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The API client for the langchain-esque service.""" +import logging +import mimetypes +import typing + +import requests + +_LOGGER = logging.getLogger(__name__) + + +class ChatClient: + """A client for connecting the the lanchain-esque service.""" + + def __init__(self, server_url: str, model_name: str) -> None: + """Initialize the client.""" + self.server_url = server_url + self._model_name = model_name + self.default_model = "llama2-7B-chat" + + @property + def model_name(self) -> str: + """Return the friendly model name.""" + return self._model_name + + def search( + self, prompt: str + ) -> typing.List[typing.Dict[str, typing.Union[str, float]]]: + """Search for relevant documents and return json data.""" + data = {"content": prompt, "num_docs": 4} + headers = {"accept": "application/json", "Content-Type": "application/json"} + url = f"{self.server_url}/documentSearch" + _LOGGER.debug( + "looking up documents - %s", str({"server_url": url, "post_data": data}) + ) + + with requests.post(url, headers=headers, json=data, timeout=30) as req: + response = req.json() + return typing.cast( + typing.List[typing.Dict[str, typing.Union[str, float]]], response + ) + + def predict( + self, query: str, use_knowledge_base: bool, num_tokens: int + ) -> typing.Generator[str, None, None]: + """Make a model prediction.""" + data = { + "question": query, + "context": "", + "use_knowledge_base": use_knowledge_base, + "num_tokens": num_tokens, + } + url = f"{self.server_url}/generate" + _LOGGER.debug( + "making inference request - %s", str({"server_url": url, "post_data": data}) + ) + + with requests.post(url, stream=True, json=data, timeout=10) as req: + for chunk in req.iter_content(16): + yield chunk.decode("UTF-8") + + def upload_documents(self, file_paths: typing.List[str]) -> None: + """Upload documents to the kb.""" + url = f"{self.server_url}/uploadDocument" + headers = { + "accept": "application/json", + } + + for fpath in file_paths: + mime_type, _ = mimetypes.guess_type(fpath) + # pylint: disable-next=consider-using-with # with pattern is not intuitive here + files = {"file": (fpath, open(fpath, "rb"), mime_type)} + + _LOGGER.debug( + "uploading file - %s", + str({"server_url": url, "file": fpath}), + ) + + _ = requests.post( + url, headers=headers, files=files, timeout=30 # type: ignore [arg-type] + ) diff --git a/RetrievalAugmentedGeneration/frontend/frontend/configuration.py b/RetrievalAugmentedGeneration/frontend/frontend/configuration.py new file mode 100644 index 000000000..56f2d3c26 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/configuration.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""The definition of the application configuration.""" +from frontend.configuration_wizard import ConfigWizard, configclass, configfield + + +@configclass +class AppConfig(ConfigWizard): + """Configuration class for the application. + + :cvar triton: The configuration of the chat server. + :type triton: ChatConfig + :cvar model: The configuration of the model + :type triton: ModelConfig + """ + + server_url: str = configfield( + "serverUrl", + default="http://10.110.17.73", + help_txt="The location of the chat server.", + ) + server_port: str = configfield( + "serverPort", + default="8000", + help_txt="The port on which the chat server is listening for HTTP requests.", + ) + model_name: str = configfield( + "modelName", + default="llama2-7B-chat", + help_txt="The name of the hosted LLM model.", + ) diff --git a/RetrievalAugmentedGeneration/frontend/frontend/configuration_wizard.py b/RetrievalAugmentedGeneration/frontend/frontend/configuration_wizard.py new file mode 100644 index 000000000..d63d9e416 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/configuration_wizard.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module containing utilities for defining application configuration. + +This module provides a configuration wizard class that can read configuration data from YAML, JSON, and environment +variables. The configuration wizard is based heavily off of the JSON and YAML wizards from the `dataclass-wizard` +Python package. That package is in-turn based heavily off of the built-in `dataclass` module. + +This module adds Environment Variable parsing to config file reading. +""" +# pylint: disable=too-many-lines; this file is meant to be portable between projects so everything is put into one file + +import json +import logging +import os +from dataclasses import _MISSING_TYPE, dataclass +from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union + +import yaml +from dataclass_wizard import ( + JSONWizard, + LoadMeta, + YAMLWizard, + errors, + fromdict, + json_field, +) +from dataclass_wizard.models import JSONField +from dataclass_wizard.utils.string_conv import to_camel_case + +configclass = dataclass(frozen=True) +ENV_BASE = "APP" +_LOGGER = logging.getLogger(__name__) + + +def configfield( + name: str, *, env: bool = True, help_txt: str = "", **kwargs: Any +) -> JSONField: + """Create a data class field with the specified name in JSON format. + + :param name: The name of the field. + :type name: str + :param env: Whether this field should be configurable from an environment variable. + :type env: bool + :param help_txt: The description of this field that is used in help docs. + :type help_txt: str + :param **kwargs: Optional keyword arguments to customize the JSON field. 
More information here: + https://dataclass-wizard.readthedocs.io/en/latest/dataclass_wizard.html#dataclass_wizard.json_field + :type **kwargs: Any + :returns: A JSONField instance with the specified name and optional parameters. + :rtype: JSONField + + :raises TypeError: If the provided name is not a string. + """ + # sanitize specified name + if not isinstance(name, str): + raise TypeError("Provided name must be a string.") + json_name = to_camel_case(name) + + # update metadata + meta = kwargs.get("metadata", {}) + meta["env"] = env + meta["help"] = help_txt + kwargs["metadata"] = meta + + # create the data class field + field = json_field(json_name, **kwargs) + return field + + +class _Color: + """A collection of colors used when writing output to the shell.""" + + # pylint: disable=too-few-public-methods; this class does not require methods. + + PURPLE = "\033[95m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + + +class ConfigWizard(JSONWizard, YAMLWizard): # type: ignore[misc] # dataclass-wizard doesn't provide stubs + """A configuration wizard class that can read configuration data from YAML, JSON, and environment variables.""" + + # pylint: disable=arguments-differ,arguments-renamed; this class intentionally reduces arguments for some methods. + + @classmethod + def print_help( + cls, + help_printer: Callable[[str], Any], + *, + env_parent: Optional[str] = None, + json_parent: Optional[Tuple[str, ...]] = None, + ) -> None: + """Print the help documentation for the application configuration with the provided `write` function. + + :param help_printer: The `write` function that will be used to output the data. + :param help_printer: Callable[[str], None] + :param env_parent: The name of the parent environment variable. Leave blank, used for recursion. + :type env_parent: Optional[str] + :param json_parent: The name of the parent JSON key. Leave blank, used for recursion. + :type json_parent: Optional[Tuple[str, ...]] + :returns: A list of tuples with one item per configuration value. Each item will have the environment variable + and a tuple to the path in configuration. + :rtype: List[Tuple[str, Tuple[str, ...]]] + """ + if not env_parent: + env_parent = "" + help_printer("---\n") + if not json_parent: + json_parent = () + + for ( + _, + val, + ) in ( + cls.__dataclass_fields__.items() # pylint: disable=no-member; false positive + ): # pylint: disable=no-member; member is added by dataclass. 
+ jsonname = val.json.keys[0] + envname = jsonname.upper() + full_envname = f"{ENV_BASE}{env_parent}_{envname}" + is_embedded_config = hasattr(val.type, "envvars") + + # print the help data + indent = len(json_parent) * 2 + if is_embedded_config: + default = "" + elif not isinstance(val.default_factory, _MISSING_TYPE): + default = val.default_factory() + elif isinstance(val.default, _MISSING_TYPE): + default = "NO-DEFAULT-VALUE" + else: + default = val.default + help_printer( + f"{_Color.BOLD}{' ' * indent}{jsonname}:{_Color.END} {default}\n" + ) + + # print comments + if is_embedded_config: + indent += 2 + if val.metadata.get("help"): + help_printer(f"{' ' * indent}# {val.metadata['help']}\n") + if not is_embedded_config: + typestr = getattr(val.type, "__name__", None) or str(val.type).replace( + "typing.", "" + ) + help_printer(f"{' ' * indent}# Type: {typestr}\n") + if val.metadata.get("env", True): + help_printer(f"{' ' * indent}# ENV Variable: {full_envname}\n") + # if not is_embedded_config: + help_printer("\n") + + if is_embedded_config: + new_env_parent = f"{env_parent}_{envname}" + new_json_parent = json_parent + (jsonname,) + val.type.print_help( + help_printer, env_parent=new_env_parent, json_parent=new_json_parent + ) + + help_printer("\n") + + @classmethod + def envvars( + cls, + env_parent: Optional[str] = None, + json_parent: Optional[Tuple[str, ...]] = None, + ) -> List[Tuple[str, Tuple[str, ...], type]]: + """Calculate valid environment variables and their config structure location. + + :param env_parent: The name of the parent environment variable. + :type env_parent: Optional[str] + :param json_parent: The name of the parent JSON key. + :type json_parent: Optional[Tuple[str, ...]] + :returns: A list of tuples with one item per configuration value. Each item will have the environment variable, + a tuple to the path in configuration, and they type of the value. + :rtype: List[Tuple[str, Tuple[str, ...], type]] + """ + if not env_parent: + env_parent = "" + if not json_parent: + json_parent = () + output = [] + + for ( + _, + val, + ) in ( + cls.__dataclass_fields__.items() # pylint: disable=no-member; false positive + ): # pylint: disable=no-member; member is added by dataclass. + jsonname = val.json.keys[0] + envname = jsonname.upper() + full_envname = f"{ENV_BASE}{env_parent}_{envname}" + is_embedded_config = hasattr(val.type, "envvars") + + # add entry to output list + if is_embedded_config: + new_env_parent = f"{env_parent}_{envname}" + new_json_parent = json_parent + (jsonname,) + output += val.type.envvars( + env_parent=new_env_parent, json_parent=new_json_parent + ) + elif val.metadata.get("env", True): + output += [(full_envname, json_parent + (jsonname,), val.type)] + + return output + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConfigWizard": + """Create a ConfigWizard instance from a dictionary. + + :param data: The dictionary containing the configuration data. + :type data: Dict[str, Any] + :returns: A ConfigWizard instance created from the input dictionary. + :rtype: ConfigWizard + + :raises RuntimeError: If the configuration data is not a dictionary. 
+ """ + # sanitize data + if not data: + data = {} + if not isinstance(data, dict): + raise RuntimeError("Configuration data is not a dictionary.") + + # parse env variables + for envvar in cls.envvars(): + var_name, conf_path, var_type = envvar + var_value = os.environ.get(var_name) + if var_value: + var_value = try_json_load(var_value) + update_dict(data, conf_path, var_value) + _LOGGER.debug( + "Found EnvVar Config - %s:%s = %s", + var_name, + str(var_type), + repr(var_value), + ) + + LoadMeta(key_transform="CAMEL").bind_to(cls) + return fromdict(cls, data) # type: ignore[no-any-return] # dataclass-wizard doesn't provide stubs + + @classmethod + def from_file(cls, filepath: str) -> Optional["ConfigWizard"]: + """Load the application configuration from the specified file. + + The file must be either in JSON or YAML format. + + :returns: The fully processed configuration file contents. If the file was unreadable, None will be returned. + :rtype: Optional["ConfigWizard"] + """ + # open the file + try: + # pylint: disable-next=consider-using-with; using a with would make exception handling even more ugly + file = open(filepath, encoding="utf-8") + except FileNotFoundError: + _LOGGER.error("The configuration file cannot be found.") + file = None + except PermissionError: + _LOGGER.error( + "Permission denied when trying to read the configuration file." + ) + file = None + if not file: + return None + + # read the file + try: + data = read_json_or_yaml(file) + except ValueError as err: + _LOGGER.error( + "Configuration file must be valid JSON or YAML. The following errors occured:\n%s", + str(err), + ) + data = None + config = None + finally: + file.close() + + # parse the file + if data: + try: + config = cls.from_dict(data) + except errors.MissingFields as err: + _LOGGER.error( + "Configuration is missing required fields: \n%s", str(err) + ) + config = None + except errors.ParseError as err: + _LOGGER.error("Invalid configuration value provided:\n%s", str(err)) + config = None + else: + config = cls.from_dict({}) + + return config + + +def read_json_or_yaml(stream: TextIO) -> Dict[str, Any]: + """Read a file without knowing if it is JSON or YAML formatted. + + The file will first be assumed to be JSON formatted. If this fails, an attempt to parse the file with the YAML + parser will be made. If both of these fail, an exception will be raised that contains the exception strings returned + by both the parsers. + + :param stream: An IO stream that allows seeking. + :type stream: typing.TextIO + :returns: The parsed file contents. + :rtype: typing.Dict[str, typing.Any]: + :raises ValueError: If the IO stream is not seekable or if the file doesn't appear to be JSON or YAML formatted. 
+ """ + exceptions: Dict[str, Union[None, ValueError, yaml.error.YAMLError]] = { + "JSON": None, + "YAML": None, + } + data: Dict[str, Any] + + # ensure we can rewind the file + if not stream.seekable(): + raise ValueError("The provided stream must be seekable.") + + # attempt to read json + try: + data = json.loads(stream.read()) + except ValueError as err: + exceptions["JSON"] = err + else: + return data + finally: + stream.seek(0) + + # attempt to read yaml + try: + data = yaml.safe_load(stream.read()) + except (yaml.error.YAMLError, ValueError) as err: + exceptions["YAML"] = err + else: + return data + + # neither json nor yaml + err_msg = "\n\n".join( + [key + " Parser Errors:\n" + str(val) for key, val in exceptions.items()] + ) + raise ValueError(err_msg) + + +def try_json_load(value: str) -> Any: + """Try parsing the value as JSON and silently ignore errors. + + :param value: The value on which a JSON load should be attempted. + :type value: str + :returns: Either the parsed JSON or the provided value. + :rtype: typing.Any + """ + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + +def update_dict( + data: Dict[str, Any], + path: Tuple[str, ...], + value: Any, + overwrite: bool = False, +) -> None: + """Update a dictionary with a new value at a given path. + + :param data: The dictionary to be updated. + :type data: Dict[str, Any] + :param path: The path to the key that should be updated. + :type path: Tuple[str, ...] + :param value: The new value to be set at the specified path. + :type value: Any + :param overwrite: If True, overwrite the existing value. Otherwise, don't update if the key already exists. + :type overwrite: bool + :returns: None + """ + end = len(path) + target = data + for idx, key in enumerate(path, 1): + # on the last field in path, update the dict if necessary + if idx == end: + if overwrite or not target.get(key): + target[key] = value + return + + # verify the next hop exists + if not target.get(key): + target[key] = {} + + # if the next hop is not a dict, exit + if not isinstance(target.get(key), dict): + return + + # get next hop + target = target.get(key) # type: ignore[assignment] # type has already been enforced. diff --git a/RetrievalAugmentedGeneration/frontend/frontend/pages/__init__.py b/RetrievalAugmentedGeneration/frontend/frontend/pages/__init__.py new file mode 100644 index 000000000..f2b6c7705 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/pages/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""This module contains definitions for all the frontend pages.""" +from frontend.pages import converse, kb + +__all__ = ["converse", "kb"] diff --git a/RetrievalAugmentedGeneration/frontend/frontend/pages/converse.py b/RetrievalAugmentedGeneration/frontend/frontend/pages/converse.py new file mode 100644 index 000000000..e67f47938 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/pages/converse.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the frontend gui for having a conversation.""" +import functools +import logging +from typing import Any, Dict, List, Tuple, Union + +import gradio as gr + +from frontend import assets, chat_client + +_LOGGER = logging.getLogger(__name__) +PATH = "/converse" +TITLE = "Converse" +OUTPUT_TOKENS = 250 +MAX_DOCS = 5 + +_LOCAL_CSS = """ + +#contextbox { + overflow-y: scroll !important; + max-height: 400px; +} +""" + + +def build_page(client: chat_client.ChatClient) -> gr.Blocks: + """Buiild the gradio page to be mounted in the frame.""" + kui_theme, kui_styles = assets.load_theme("kaizen") + + with gr.Blocks(title=TITLE, theme=kui_theme, css=kui_styles + _LOCAL_CSS) as page: + # create the page header + gr.Markdown(f"# {TITLE}") + + # chat logs + with gr.Row(equal_height=True): + chatbot = gr.Chatbot(scale=2, label=client.model_name) + context = gr.JSON( + scale=1, + label="Knowledge Base Context", + visible=False, + elem_id="contextbox", + ) + + with gr.Row(): + with gr.Column(scale=10, min_width=600): + kb_checkbox = gr.Checkbox( + label="Use knowledge base", info="", value=False + ) + msg = gr.Textbox( + show_label=False, + placeholder="Enter text and press ENTER", + container=False, + ) + + # user feedback + with gr.Row(): + # _ = gr.Button(value="👍 Upvote") + # _ = gr.Button(value="👎 Downvote") + # _ = gr.Button(value="⚠️ Flag") + submit_btn = gr.Button(value="Submit") + _ = gr.ClearButton(msg) + _ = gr.ClearButton([msg, chatbot], value="Clear history") + ctx_show = gr.Button(value="Show Context") + ctx_hide = gr.Button(value="Hide Context", visible=False) + + # hide/show context + def _toggle_context(btn: str) -> Dict[gr.component, Dict[Any, Any]]: + if btn == "Show Context": + out = [True, False, True] + if btn == "Hide Context": + out = [False, True, False] + return { + context: gr.update(visible=out[0]), + ctx_show: gr.update(visible=out[1]), + ctx_hide: gr.update(visible=out[2]), + } + + ctx_show.click(_toggle_context, [ctx_show], [context, ctx_show, ctx_hide]) + ctx_hide.click(_toggle_context, [ctx_hide], [context, ctx_show, ctx_hide]) + + # form actions + _my_build_stream = functools.partial(_stream_predict, client) + msg.submit( + _my_build_stream, [kb_checkbox, msg, chatbot], [msg, chatbot, context] + ) + submit_btn.click( + _my_build_stream, [kb_checkbox, msg, chatbot], [msg, chatbot, context] + ) + + page.queue() + return page + + +def 
_stream_predict( + client: chat_client.ChatClient, + use_knowledge_base: bool, + question: str, + chat_history: List[Tuple[str, str]], +) -> Any: + """Make a prediction of the response to the prompt.""" + chunks = "" + _LOGGER.info( + "processing inference request - %s", + str({"prompt": question, "use_knowledge_base": use_knowledge_base}), + ) + + documents: Union[None, List[Dict[str, Union[str, float]]]] = None + if use_knowledge_base: + documents = client.search(question) + + for chunk in client.predict(question, use_knowledge_base, OUTPUT_TOKENS): + chunks += chunk + yield "", chat_history + [[question, chunks]], documents diff --git a/RetrievalAugmentedGeneration/frontend/frontend/pages/kb.py b/RetrievalAugmentedGeneration/frontend/frontend/pages/kb.py new file mode 100644 index 000000000..2feecfc42 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/pages/kb.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the frontend gui for chat.""" +from pathlib import Path +from typing import List + +import gradio as gr + +from frontend import assets, chat_client + +PATH = "/kb" +TITLE = "Knowledge Base Management" + + +def build_page(client: chat_client.ChatClient) -> gr.Blocks: + """Buiild the gradio page to be mounted in the frame.""" + kui_theme, kui_styles = assets.load_theme("kaizen") + + with gr.Blocks(title=TITLE, theme=kui_theme, css=kui_styles) as page: + # create the page header + gr.Markdown(f"# {TITLE}") + + with gr.Row(): + upload_button = gr.UploadButton( + "Add File", file_types=["pdf"], file_count="multiple" + ) + with gr.Row(): + file_output = gr.File() + + # form actions + upload_button.upload( + lambda files: upload_file(files, client), upload_button, file_output + ) + + page.queue() + return page + + +def upload_file(files: List[Path], client: chat_client.ChatClient) -> List[str]: + """Use the client to upload a file to the knowledge base.""" + file_paths = [file.name for file in files] + client.upload_documents(file_paths) + return file_paths diff --git a/RetrievalAugmentedGeneration/frontend/frontend/static/404.html b/RetrievalAugmentedGeneration/frontend/frontend/static/404.html new file mode 100644 index 000000000..bafa01bee --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/static/404.html @@ -0,0 +1 @@ +
\ No newline at end of file diff --git a/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_buildManifest.js b/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_buildManifest.js new file mode 100644 index 000000000..9556320b7 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_buildManifest.js @@ -0,0 +1 @@ +self.__BUILD_MANIFEST=function(e){return{__rewrites:{beforeFiles:[],afterFiles:[],fallback:[]},"/":["static/chunks/pages/index-1a1d31dae38463f7.js"],"/_error":["static/chunks/pages/_error-54de1933a164a1ff.js"],"/converse":[e,"static/chunks/pages/converse-61880f01babd873a.js"],"/kb":[e,"static/chunks/pages/kb-cf0d102293dc0a74.js"],sortedPages:["/","/_app","/_error","/converse","/kb"]}}("static/chunks/78-a36dca5d49fafb86.js"),self.__BUILD_MANIFEST_CB&&self.__BUILD_MANIFEST_CB(); \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_ssgManifest.js b/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_ssgManifest.js new file mode 100644 index 000000000..0511aa895 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/WuNGAl0x4o1D5HqLxhHMt/_ssgManifest.js @@ -0,0 +1 @@ +self.__SSG_MANIFEST=new Set,self.__SSG_MANIFEST_CB&&self.__SSG_MANIFEST_CB(); \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/78-a36dca5d49fafb86.js b/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/78-a36dca5d49fafb86.js new file mode 100644 index 000000000..c26ec3338 --- /dev/null +++ b/RetrievalAugmentedGeneration/frontend/frontend/static/_next/static/chunks/78-a36dca5d49fafb86.js @@ -0,0 +1 @@ +(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[78],{6649:function(e,n,t){"use strict";var i,o,r=t(5697),a=t.n(r),s=t(7294);function c(){return(c=Object.assign||function(e){for(var n=1;n=0||(o[t]=e[t]);return o}var u=(i=o={exports:{}},o.exports,function(e){if("undefined"!=typeof window){var n,t=0,o=!1,r=!1,a=7,s="[iFrameSizer]",c=s.length,d=null,u=window.requestAnimationFrame,l={max:1,scroll:1,bodyScroll:1,documentElementScroll:1},f={},m=null,g={autoResize:!0,bodyBackground:null,bodyMargin:null,bodyMarginV1:8,bodyPadding:null,checkOrigin:!0,inPageLinks:!1,enablePublicMethods:!0,heightCalculationMethod:"bodyOffset",id:"iFrameResizer",interval:32,log:!1,maxHeight:1/0,maxWidth:1/0,minHeight:0,minWidth:0,resizeFrom:"parent",scrolling:!1,sizeHeight:!0,sizeWidth:!1,warningTimeout:5e3,tolerance:0,widthCalculationMethod:"scroll",onClose:function(){return!0},onClosed:function(){},onInit:function(){},onMessage:function(){z("onMessage function not defined")},onMouseEnter:function(){},onMouseLeave:function(){},onResized:function(){},onScroll:function(){return!0}},h={};window.jQuery&&((n=window.jQuery).fn?n.fn.iFrameResize||(n.fn.iFrameResize=function(e){return this.filter("iframe").each(function(n,t){S(t,e)}).end()}):M("","Unable to bind to jQuery, it is not fully loaded.")),"function"==typeof e&&e.amd?e([],A):i.exports=A(),window.iFrameResize=window.iFrameResize||A()}function p(){return window.MutationObserver||window.WebKitMutationObserver||window.MozMutationObserver}function w(e,n,t){e.addEventListener(n,t,!1)}function b(e,n,t){e.removeEventListener(n,t,!1)}function y(e){return f[e]?f[e].log:o}function v(e,n){k("log",e,n,y(e))}function 
M(e,n){k("info",e,n,y(e))}function z(e,n){k("warn",e,n,!0)}function k(e,n,t,i){if(!0===i&&"object"==typeof window.console){var o;console[e](s+"["+(o="Host page: "+n,window.top!==window.self&&(o=window.parentIFrame&&window.parentIFrame.getId?window.parentIFrame.getId()+": "+n:"Nested host page: "+n),o)+"]",t)}}function x(e){function n(){t("Height"),t("Width"),W(function(){var e;C(L),T(A),R(A,"onResized",L)},L,"init")}function t(e){var n=Number(f[A]["max"+e]),t=Number(f[A]["min"+e]),i=e.toLowerCase(),o=Number(L[i]);v(A,"Checking "+i+" is in range "+t+"-"+n),on&&(o=n,v(A,"Set "+i+" to max value")),L[i]=""+o}function i(e){return j.substr(j.indexOf(":")+a+e)}function o(e,n){var t;t=function(){var t,i;P("Send Page Info","pageInfo:"+(t=document.body.getBoundingClientRect(),JSON.stringify({iframeHeight:(i=L.iframe.getBoundingClientRect()).height,iframeWidth:i.width,clientHeight:Math.max(document.documentElement.clientHeight,window.innerHeight||0),clientWidth:Math.max(document.documentElement.clientWidth,window.innerWidth||0),offsetTop:parseInt(i.top-t.top,10),offsetLeft:parseInt(i.left-t.left,10),scrollTop:window.pageYOffset,scrollLeft:window.pageXOffset,documentHeight:document.documentElement.clientHeight,documentWidth:document.documentElement.clientWidth,windowHeight:window.innerHeight,windowWidth:window.innerWidth})),e,n)},h[n]||(h[n]=setTimeout(function(){h[n]=null,t()},32))}function r(e){var n=e.getBoundingClientRect();return F(A),{x:Math.floor(Number(n.left)+Number(d.x)),y:Math.floor(Number(n.top)+Number(d.y))}}function u(e){var n=e?r(L.iframe):{x:0,y:0},t={x:Number(L.width)+n.x,y:Number(L.height)+n.y};v(A,"Reposition requested from iFrame (offset x:"+n.x+" y:"+n.y+")"),window.top!==window.self?window.parentIFrame?window.parentIFrame["scrollTo"+(e?"Offset":"")](t.x,t.y):z(A,"Unable to scroll to requested position, window.parentIFrame not found"):(d=t,l(),v(A,"--"))}function l(){var e;!1!==R(A,"onScroll",d)?T(A):d=null}function m(e){var n;R(A,e,{iframe:L.iframe,screenX:L.width,screenY:L.height,type:L.type})}var g,p,y,k,x,I,S,N,j=e.data,L={},A=null;"[iFrameResizerChild]Ready"===j?!function(){for(var e in f)P("iFrame requested init",H(e),f[e].iframe,e)}():s===(""+j).substr(0,c)&&j.substr(c).split(":")[0]in f?(p=(g=j.substr(c).split(":"))[1]?parseInt(g[1],10):0,y=f[g[0]]&&f[g[0]].iframe,k=getComputedStyle(y),f[A=(L={iframe:y,id:g[0],height:p+("border-box"!==k.boxSizing?0:(k.paddingTop?parseInt(k.paddingTop,10):0)+(k.paddingBottom?parseInt(k.paddingBottom,10):0))+("border-box"!==k.boxSizing?0:(k.borderTopWidth?parseInt(k.borderTopWidth,10):0)+(k.borderBottomWidth?parseInt(k.borderBottomWidth,10):0)),width:g[2],type:g[3]}).id]&&(f[A].loaded=!0),(x=L.type in{true:1,false:1,undefined:1})&&v(A,"Ignoring init message from meta parent page"),!x&&(S=!0,f[I=A]||(S=!1,z(L.type+" No settings for "+I+". 
Message was: "+j)),S)&&(v(A,"Received: "+j),N=!0,null===L.iframe&&(z(A,"IFrame ("+L.id+") not found"),N=!1),N&&function(){var n,t=e.origin,i=f[A]&&f[A].checkOrigin;if(i&&""+t!="null"&&!(i.constructor===Array?function(){var e=0,n=!1;for(v(A,"Checking connection is from allowed list of origins: "+i);ef[c]["max"+e])throw Error("Value for min"+e+" can not be greater than max"+e)}t("Height"),t("Width"),e("maxHeight"),e("minHeight"),e("maxWidth"),e("minWidth")}(),("number"==typeof(f[c]&&f[c].bodyMargin)||"0"===(f[c]&&f[c].bodyMargin))&&(f[c].bodyMarginV1=f[c].bodyMargin,f[c].bodyMargin=""+f[c].bodyMargin+"px"),a=H(c),(s=p())&&n.parentNode&&new s(function(e){e.forEach(function(e){Array.prototype.slice.call(e.removedNodes).forEach(function(e){e===n&&O(n)})})}).observe(n.parentNode,{childList:!0}),w(n,"load",function(){var t,i;P("iFrame.onload",a,n,e,!0),t=f[c]&&f[c].firstRun,i=f[c]&&f[c].heightCalculationMethod in l,!t&&i&&E({iframe:n,height:0,width:0,type:"init"})}),P("init",a,n,e,!0),f[c]&&(f[c].iframe.iFrameResizer={close:O.bind(null,f[c].iframe),removeListeners:I.bind(null,f[c].iframe),resize:P.bind(null,"Window resize","resize",f[c].iframe),moveToAnchor:function(e){P("Move to anchor","moveToAnchor:"+e,f[c].iframe,c)},sendMessage:function(e){P("Send Message","message:"+(e=JSON.stringify(e)),f[c].iframe,c)}}))}function N(e,n){null===m&&(m=setTimeout(function(){m=null,e()},n))}function j(){"hidden"!==document.visibilityState&&(v("document","Trigger event: Visiblity change"),N(function(){L("Tab Visable","resize")},16))}function L(e,n){Object.keys(f).forEach(function(t){f[t]&&"parent"===f[t].resizeFrom&&f[t].autoResize&&!f[t].firstRun&&P(e,n,f[t].iframe,t)})}function A(){var n;function t(e,t){t&&(function(){if(t.tagName){if("IFRAME"!==t.tagName.toUpperCase())throw TypeError("Expected