# Snowflake RAG demo

## Part 1: Cortex Search Setup

### 1. Download your search database

For example, you can use the [KILT knowledge source](https://github.com/facebookresearch/KILT) and download its [preprocessed 2019/08/01 Wikipedia dump](http://dl.fbaipublicfiles.com/KILT/kilt_knowledgesource.json).

### 2. Apply truncation to the knowledge dataset

Because the Wikipedia passages are too long to achieve good retrieval quality, we need to truncate the whole dataset by splitting each article into shorter passages. Here is an example of how to do the truncation.

In [3]:
import json
import re

input_file = "kilt_knowledgesource.json"  # Source dump; read below as one JSON object per line
output_file_base = "wiki2019_part"  # Base name for output files
# Size cap per output part file: 240 MiB in bytes.
# NOTE(review): presumably chosen to stay below a 250 MB upload limit - confirm.
max_file_size = 240 * 1024 * 1024


def split_into_chunks(text, max_sentences=4):
    """
    Split text into chunks of at most max_sentences sentences.

    A sentence boundary is any run of whitespace that directly follows
    '.', '!' or '?'. Sentences within a chunk are re-joined with a single
    space.

    Args:
        text: The text to split.
        max_sentences: Maximum number of sentences per chunk (must be >= 1).

    Returns:
        A list of chunk strings. Empty or whitespace-only input yields an
        empty list, so callers never receive blank chunks.

    Raises:
        ValueError: If max_sentences is smaller than 1.
    """
    if max_sentences < 1:
        raise ValueError("max_sentences must be at least 1")

    # Drop blank fragments: re.split leaves an empty trailing piece when the
    # text ends with whitespace, and returns the input itself when it is blank.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    return [" ".join(sentences[i:i + max_sentences]) for i in range(0, len(sentences), max_sentences)]


# Output-rotation state: index and byte size of the part file being written.
current_file_index = 1
current_file_size = 0
output_file = open(f"{output_file_base}_{current_file_index}.json", 'w', encoding='utf-8')


def _write_entry(entry_text):
    """Append one passage as a JSON line, rotating part files at max_file_size."""
    global current_file_index, current_file_size, output_file

    serialized_entry = json.dumps({"text": entry_text}) + '\n'
    entry_size = len(serialized_entry.encode('utf-8'))

    # Start a new part when this entry would push the current one over the
    # cap. The current_file_size check avoids leaving an empty part behind
    # when a single entry is larger than the cap on its own.
    if current_file_size and current_file_size + entry_size > max_file_size:
        output_file.close()
        current_file_index += 1
        output_file = open(f"{output_file_base}_{current_file_index}.json", 'w', encoding='utf-8')
        current_file_size = 0

    output_file.write(serialized_entry)
    current_file_size += entry_size


def _write_section(wikipedia_title, section_name, section_lines):
    """Chunk one section and write each chunk as '<title> - <section>: <chunk>'."""
    full_text = " ".join(section_lines).replace("BULLET::::", "").strip()
    for chunk in split_into_chunks(full_text, max_sentences=4):
        _write_entry(f"{wikipedia_title} - {section_name}: {chunk}")


try:
    with open(input_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():  # Skip empty lines
                continue

            json_data = json.loads(line)
            wikipedia_title = json_data.get("wikipedia_title")
            text_list = json_data.get("text", [])

            # Lead content: everything before the first "Section::::" marker
            # is written as a single, un-chunked entry.
            first_content = []
            for text in text_list:
                if text.startswith("Section::::"):
                    break
                first_content.append(text.strip())

            if first_content:
                full_text = " ".join(first_content).replace("BULLET::::", "").strip()
                _write_entry(f"{wikipedia_title}: {full_text}")

            # Sections: collect the lines of each section and flush them when
            # the next section header (or the end of the article) is reached.
            current_section = None
            section_content = []
            skip_section = False

            for text in text_list:
                if text.startswith("Section::::"):
                    if current_section and section_content and not skip_section:
                        _write_section(wikipedia_title, current_section, section_content)

                    # The section name is the header text up to the first '.'.
                    current_section = text.replace("Section::::", "").split('.')[0].strip()
                    section_content = []
                    # Reference-style sections are not useful for retrieval.
                    skip_section = current_section.startswith(("External links", "See also"))
                elif not skip_section:
                    section_content.append(text.strip())

            # Flush the last section of the article.
            if current_section and section_content and not skip_section:
                _write_section(wikipedia_title, current_section, section_content)
finally:
    # Close the current part file even if reading or parsing fails midway.
    output_file.close()

### 3. Upload the dataset

Please see [Cortex Search tutorials](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/overview-tutorials) for more instructions.

#### Option 1: Use Snowsight to load data into table

In a supported web browser, navigate to [Snowsight](https://app.snowflake.com). 

After logging in, click "Create" in the upper left and choose "Table: From File". 

Upload your files for the dataset, create your own database and schema (You may need to create a new warehouse if you don't have one). 

Click "Next", you can review the data columns. 

If you are uploading JSON files, it is better to uncheck "Load as a single variant column?". 

You can also rename the column that holds the key retrieval content to "text". 

You can then click "load" to start the loading.

#### Option 2: Use Snowflake's Python library to load data into a table (see Step 4, Option 2 below for the full code that inserts the data and creates the Search Service programmatically)

### 4. Build the Cortex Search Service

#### Option 1: Use Snowsight to create the service

Click on "AI & ML" on the left menu, go to "Studio", and choose "Create a Cortex Search Service" in the studio.

Choose your own "Database and Schema" and set your service name in "Name".

You then need to go through the following steps:

a. Select data to be indexed \
b. Select columns to include in the service \
c. Select a column to search \
d. Select attribute column(s) (This one can be skipped) \
e. Configure your Search Service (Usually use the default one)

Then you just need to wait until the service creation is done.

#### Option 2: Use Snowflake's Python library to create the service


In [None]:
import os

# Register the Snowflake / Cortex settings that the later cells read back
# through os.environ. Replace every placeholder with your own value, and
# avoid committing a real password.
os.environ.update(
    {
        "SNOWFLAKE_ACCOUNT": "your_account",
        "SNOWFLAKE_USER": "your_username",
        "SNOWFLAKE_ROLE": "your_role",
        "SNOWFLAKE_WAREHOUSE": "your_warehouse",
        "SNOWFLAKE_DATABASE": "your_database",
        "SNOWFLAKE_SCHEMA": "your_schema",
        "CORTEX_SERVICE_NAME": "your_service_name",
        "SNOWFLAKE_PASSWORD": "your_password",
    }
)

In [None]:
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas
import pandas as pd
import json
import glob
import os

def load_json_files(directory_path, max_files=1):
    """Load JSON-lines files from a directory into a pandas DataFrame.

    Args:
        directory_path: Directory searched (non-recursively) for ``*.json``
            files. Each file is read as one JSON object per line.
        max_files: Maximum number of files to read, taken in sorted path
            order. Defaults to 1, which keeps the demo behaviour of loading a
            single part file; pass ``None`` to load every file.

    Returns:
        A DataFrame with one row per successfully parsed line. It is empty
        (no columns) when nothing could be loaded.

    Raises:
        ValueError: If max_files is negative.
    """
    if max_files is not None and max_files < 0:
        raise ValueError("max_files must be None or a non-negative integer")

    # glob returns paths in arbitrary order; sort so the selection is stable.
    json_files = sorted(glob.glob(os.path.join(directory_path, "*.json")))
    print(f"Found {len(json_files)} JSON files")

    if max_files is not None:
        json_files = json_files[:max_files]

    # List to store all records
    all_records = []

    for file_path in json_files:
        print(f"Processing {file_path}")
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                if not line.strip():  # Blank lines are not parse errors
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    # Best effort: report the bad line and keep going.
                    print(f"Error parsing JSON in {file_path}: {e}")
                    continue
                all_records.append(record)

    return pd.DataFrame(all_records)

def main():
    """Load the JSON passages into Snowflake and build the Cortex Search Service.

    Steps: read the passage files from ``./simple_RAG``, recreate the
    ``WIKI_TEXT`` table, bulk-load the passages into it, make sure the
    ``CORTEX_SEARCH_WH`` warehouse exists, and create (or replace) the
    ``wiki_search_service`` search service over the ``CONTENT_TEXT`` column.

    Connection settings are read from the ``SNOWFLAKE_*`` environment
    variables.

    Raises:
        ValueError: If no record with a ``text`` field was loaded. This is
            checked before the existing table is dropped.
    """
    # Snowflake connection parameters
    snowflake_connection_params = {
        "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
        "user": os.environ.get("SNOWFLAKE_USER"),
        "password": os.environ.get("SNOWFLAKE_PASSWORD"),
        "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
        "database": os.environ.get("SNOWFLAKE_DATABASE"),
        "schema": os.environ.get("SNOWFLAKE_SCHEMA"),
    }
    # SNOWFLAKE_ROLE is part of the documented environment setup; forward it
    # only when set so the account's default role still applies otherwise.
    role = os.environ.get("SNOWFLAKE_ROLE")
    if role:
        snowflake_connection_params["role"] = role

    # Use the specified directory path
    directory_path = "./simple_RAG"

    # Load JSON files
    print("Loading JSON files...")
    df = load_json_files(directory_path)
    print("JSON data preview:")
    print(df.head())

    # Fail before touching Snowflake: the steps below drop the existing table.
    if 'text' not in df.columns:
        raise ValueError(
            f"No records with a 'text' field were loaded from {directory_path}; nothing to upload."
        )

    # Rename the 'text' column to 'CONTENT_TEXT' to match the table definition
    df = df.rename(columns={'text': 'CONTENT_TEXT'})

    # Connect to Snowflake
    print("Connecting to Snowflake...")
    conn = snowflake.connector.connect(**snowflake_connection_params)
    print("Connected to Snowflake.")

    try:
        # Create and populate the WIKI_TEXT table
        with conn.cursor() as cursor:
            cursor.execute("DROP TABLE IF EXISTS WIKI_TEXT")
            conn.commit()

            cursor.execute("""
            CREATE TABLE WIKI_TEXT (
                CONTENT_TEXT STRING
            )
        """)
            conn.commit()
        print("Table 'WIKI_TEXT' created successfully.")

        # Write the data to Snowflake
        success, _, nrows, _ = write_pandas(conn, df, 'WIKI_TEXT')
        print(f"Data load into 'WIKI_TEXT' successful: {success}")
        print(f"Number of rows inserted: {nrows}")

        # Create the warehouse for Cortex Search if it doesn't exist
        with conn.cursor() as cursor:
            cursor.execute("CREATE WAREHOUSE IF NOT EXISTS CORTEX_SEARCH_WH WAREHOUSE_SIZE = 'XSMALL'")
            conn.commit()
        print("Warehouse 'CORTEX_SEARCH_WH' is available.")

        # Create the Cortex Search Service.
        # NOTE(review): the service name is fixed to wiki_search_service here,
        # while the retrieval setup reads it from CORTEX_SERVICE_NAME - make
        # sure that variable is set to the same name.
        create_cortex_search_service = """
    CREATE OR REPLACE CORTEX SEARCH SERVICE wiki_search_service
      ON CONTENT_TEXT
      WAREHOUSE = CORTEX_SEARCH_WH
      TARGET_LAG = '1 day'
      EMBEDDING_MODEL = 'snowflake-arctic-embed-l-v2.0'
      AS (
        SELECT
            CONTENT_TEXT
        FROM WIKI_TEXT
    );
    """

        with conn.cursor() as cursor:
            cursor.execute(create_cortex_search_service)
            conn.commit()
        print("Cortex Search Service 'wiki_search_service' created successfully.")
    finally:
        # Always release the connection, including when a statement fails.
        conn.close()
        print("Snowflake connection closed.")


if __name__ == "__main__":
    main()

## Part 2: Build your RAG framework

In [None]:
from arctic_agentic_rag.template import Field, Template
from arctic_agentic_rag.template.template import set_instructions

# Instruction prompt for the answer generator: concise, context-grounded
# answers with bracketed citation numbers, or "I don't know." when the context
# is insufficient. Every instruction line is terminated by a space + newline.
generate_short_answer_prompt = "".join(
    f"{instruction} \n"
    for instruction in (
        'Answer questions with concise, fact-based answers using only the information provided in the context.',
        'Include a citation for the source of the information if available in the context,',
        'Cite the source of the information by including the context number in brackets, immediately after the relevant part of the answer.',
        'You should include the citations for all the information you use.',
        'Format the answer section as: "[content] [number]", where [number] is the context number.',
        'Examples:',
        '- "He went to school at 8 am [1]."',
        '- "She wrote the book A [2] and book B [4]."',
        '- "They are shopping for food [1][3]."',
        'The answer should be simple, clear, concise, and correct, often between 2-5 words.',
        'If the context lacks sufficient information to answer,',
        "Directly Format the answer section as: \"I don't know.\"",
    )
)

# Set the input and output template.
# The decorator attaches generate_short_answer_prompt as the template's
# instructions; the fields below declare what goes in and what comes out.
@set_instructions(generate_short_answer_prompt)
class GenerateAnswer(Template):
    # Inputs: the user question and the context text to answer from.
    question = Field(desc="the input question", mode="input")
    context = Field(desc="may contain relevant information", mode="input")
    # Output: the short answer produced by the model.
    answer = Field(desc="the output answer, often between 2-5 words", mode="output")

In [None]:
from arctic_agentic_rag.agent import TemplateAgent
from arctic_agentic_rag.retrieval import CortexRetrievalAgent
from arctic_agentic_rag.llm import BaseLLM
from arctic_agentic_rag.logging import Logger
import re

class GenerateAnswerAgent(TemplateAgent):
    # A TemplateAgent whose template is fixed to GenerateAnswer; every other
    # argument is passed through to the base class unchanged.
    def __init__(self, backbone: BaseLLM, uid: str, logger: Logger, **llm_kwargs):
        super().__init__(
            backbone=backbone,
            uid=uid,
            logger=logger,
            template=GenerateAnswer,
            **llm_kwargs,
        )
        
class RAG(TemplateAgent):
    # Simple retrieve-then-generate pipeline: fetch passages with Cortex
    # Search, ask the answer agent to answer from them, then turn the
    # answer's citation markers into a list of the cited passages.
    def __init__(
        self,
        backbone: BaseLLM,
        uid: str,
        logger: Logger,
        retrieval_config,
        **llm_kwargs,
    ):
        """Set up the retrieval agent and the answer-generation agent.

        Args:
            backbone: LLM shared by this agent and the answer generator.
            uid: Identifier passed to the base agent.
            logger: Logger shared by this agent and the answer generator.
            retrieval_config: Two-item sequence; item 0 is the connection
                config and item 1 the service config for Cortex retrieval.
            **llm_kwargs: Extra keyword arguments forwarded to the base agent.
        """
        super(RAG, self).__init__(
            backbone=backbone,
            uid=uid,
            logger=logger,
            retrieval_config=retrieval_config,
            template=GenerateAnswer,
            action_space=[],
            **llm_kwargs)

        self.retrieval = CortexRetrievalAgent(
            connection_config = retrieval_config[0],
            service_config = retrieval_config[1]
        )

        self.generate = GenerateAnswerAgent(
            backbone,
            uid=None,
            logger=logger,
        )

    # convert list to string
    def convert_context(self, context):
        """Render passages as numbered lines: "[1] ...", "[2] ...", one per line."""
        return "\n".join(f"[{i}] {item}" for i, item in enumerate(context, start=1))

    # clean the answer and get the source
    def process_source(self, answer_string, context_string):
        """Strip citation markers from an answer and collect the cited passages.

        Args:
            answer_string: Model answer containing markers such as "[2]".
            context_string: Numbered context as produced by convert_context.

        Returns:
            A tuple ``(cleaned_answer, source)``. ``cleaned_answer`` is the
            answer without citation markers. ``source`` lists the distinct
            cited passages in ascending order of their context number,
            renumbered from 1, one per line; it is the string "None" when the
            answer cites nothing that exists in the context.
        """
        cited_indices = {int(num) for num in re.findall(r'\[(\d+)\]', answer_string)}
        # Remove each marker together with the whitespace in front of it, so
        # "8 am [1]." becomes "8 am." rather than "8 am .".
        cleaned_answer_string = re.sub(r'\s*\[\d+\]', '', answer_string).strip()

        if not cited_indices:
            return cleaned_answer_string, "None"

        # An entry runs from its "[n]" marker up to the next line that starts
        # with a marker, or to the end of the string.
        pattern = r'\[(\d+)\](.*?)(?=\n\[\d+\]|$)'
        matches = re.findall(pattern, context_string, re.DOTALL)

        # Create a dictionary mapping each index to its context entry
        context_dict = {int(idx): text.strip() for idx, text in matches}

        # Ignore citation numbers that do not exist in the context instead of
        # emitting blank source lines for them.
        cited_entries = [context_dict[idx] for idx in sorted(cited_indices) if idx in context_dict]
        if not cited_entries:
            return cleaned_answer_string, "None"

        processed_source = "\n".join(
            f"[{i}] {entry}" for i, entry in enumerate(cited_entries, start=1)
        )
        return cleaned_answer_string, processed_source

    def get_result(self, question: str, k: int = 5):
        """Answer a question from retrieved passages.

        Args:
            question: The question to answer.
            k: Number of passages to retrieve (default 5).

        Returns:
            A dict with the keys "question", "answer" (citation markers
            removed) and "source" (the cited passages, or "None").
        """
        # question answering
        retrieval_context = self.retrieval.get_retrieval(question, k)
        context_string = self.convert_context(retrieval_context)
        origin_answer_input = {"question": question, "context": context_string}
        origin_answer_output = self.generate.get_result(origin_answer_input)
        origin_answer = origin_answer_output["answer"]

        cleaned_origin_answer, source = self.process_source(origin_answer, context_string)

        return {
            "question": question,
            "answer": cleaned_origin_answer,
            "source": source,
        }

    def parse_response(self, response: str):
        """Return the raw response unchanged."""
        return response

In [4]:
from arctic_agentic_rag.utils import get_snowflake_credentials

# Connection settings for Cortex Search, requested with SSO authentication.
connection_config = get_snowflake_credentials(authentication="sso")

# Where the search service lives, read from the environment variables.
service_config = dict(
    database=os.environ.get("SNOWFLAKE_DATABASE"),
    schema=os.environ.get("SNOWFLAKE_SCHEMA"),
    service_name=os.environ.get("CORTEX_SERVICE_NAME"),
)

# RAG reads item 0 as the connection config and item 1 as the service config.
retrieval_config = [connection_config, service_config]

In [8]:

from arctic_agentic_rag.llm import CortexComplete

# Run logs and the result file are saved under this base path.
logger = Logger(base_path="Result")

# Identity verification happens twice: once for retrieval, once for generation.
llm = CortexComplete(
    uid="rag-test.generator",
    model="llama3.1-405b",
    authentication="sso",
    logger=logger,
    max_retries=1,
)

agent = RAG(
    backbone=llm,
    uid="rag test",
    logger=logger,
    retrieval_config=retrieval_config,
)

# Ask a single test question and record its result.
result = agent.get_result("when did the song holiday road come out")

logger.log_final_results(result)
