In [None]:
# Import python packages & establish session
import pandas as pd
from PyPDF2 import PdfFileReader
from snowflake.snowpark.files import SnowflakeFile
from io import BytesIO
from snowflake.snowpark.types import StringType, StructField, StructType
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Grab the currently active Snowpark session; it is used further down to
# register the PDF-reading UDF and the text-chunking UDTF.
from snowflake.snowpark.context import get_active_session
session = get_active_session()


RAG Made Easy w/ Snowflake Cortex
========

Creating an end-to-end Retrieval Augmented Generation process (or RAG) directly in Snowflake.
1) Extract full text from PDF files using Snowpark.
2) Chunk those documents using Langchain in Snowpark.
3) Use Cortex to create embeddings of those chunks.
4) Use Vector Similarity to show the most similar chunk when prompting an LLM.

Remember to add the PyPDF2 and langchain packages (top right!)

In [None]:
-- Optional set up: place your PDF files in a stage for extraction.
-- This lists the files currently in the @huberman stage.
ls @huberman;

In [None]:
#Create a Snowpark based function to extract text from PDFs

def readpdf(file_path):
    """Return the text of every page of the PDF at ``file_path`` as one string.

    The file is opened in binary mode through ``SnowflakeFile``, read fully
    into an in-memory buffer and parsed with ``PdfFileReader``. The text
    extracted from each page is concatenated in page order with no separator.
    """
    page_texts = []
    with SnowflakeFile.open(file_path, 'rb') as pdf_file:
        # PdfFileReader is handed an in-memory stream holding the whole file.
        buffer = BytesIO(pdf_file.readall())
        reader = PdfFileReader(buffer)
        page_texts = [page.extract_text() for page in reader.pages]
    return "".join(page_texts)

# Register readpdf as the permanent UDF SNOWPARK_PDF (string in, string out).
# Optional: convert this cell to markdown afterwards to prevent rerunning it later.
session.udf.register(
    func=readpdf,
    name='SNOWPARK_PDF',
    input_types=[StringType()],
    return_type=StringType(),
    packages=['snowflake-snowpark-python', 'pypdf2'],
    is_permanent=True,
    replace=True,
    stage_location='LLM_DEMO.PODCASTS.UDF',
)

In [None]:
-- Run the SNOWPARK_PDF UDF over every file listed in the @huberman directory
-- and keep the extracted text next to the file's path and URL.
CREATE OR REPLACE TABLE RAW_TEXT AS
SELECT
    relative_path,
    file_url,
    snowpark_pdf(build_scoped_file_url(@huberman, relative_path)) AS raw_text
FROM DIRECTORY(@huberman);

In [None]:
-- Optional: this is expected to fail because the full document exceeds the
-- model's token limit, which is why the text needs to be chunked first.
SELECT
    SNOWFLAKE.CORTEX.COMPLETE(
        'gemma-7b',
        CONCAT('summarise the following text', raw_text)
    )
FROM RAW_TEXT
LIMIT 1;

A note on chunking
-----
Chunking is the process of splitting a large body of text into smaller 'chunks' whilst attempting to keep as much relevant information as possible. Make the chunks too small and you run the risk of removing key information that the model requires to answer the question. Make them too large and it may be harder to retrieve the correct body of text from the vector search, and you will spend more tokens than necessary.

There are many strategies for chunking and for choosing which chunks to pass to the model. For example: pass only the most relevant chunk, pass the top n most relevant chunks, or pass the most relevant chunk plus the chunk on either side of it. Play around and see what works for your use case!


In [None]:
#A class for chunking text and returning a table via UDTF
class text_chunker:
    """UDTF handler that splits one document's text into overlapping chunks."""

    def process(self, text):
        """Yield one ``(chunk, meta)`` row of strings per chunk of ``text``.

        ``chunk`` is the plain text of the chunk. ``meta`` is the string form
        of the chunk's metadata dict; because ``add_start_index`` is enabled
        the splitter is asked to record each chunk's start offset there.

        NULL or empty input yields no rows.
        """
        # Nothing to split: emit no rows instead of failing inside the splitter.
        if not text:
            return

        text_splitter = RecursiveCharacterTextSplitter(
            separators = ["\n"], # Define an appropriate separator. New line is good typically!
            chunk_size = 1000, # Adjust this as you see fit
            chunk_overlap = 50, # Lets neighbouring chunks overlap. Useful for keeping chunks contextual
            length_function = len,
            add_start_index = True # Optional but useful if you'd like to feed the chunk before/after
        )

        # Read the text and metadata from each document by attribute so the
        # row holds the chunk text itself, independent of how the document
        # object happens to iterate.
        for doc in text_splitter.create_documents([text]):
            yield (doc.page_content, str(doc.metadata))


# Register the UDTF - set the stage location

# Each row returned by CHUNK_TEXT has two string columns.
schema = StructType([
    StructField("chunk", StringType()),
    StructField("meta", StringType()),
])

session.udtf.register(
    handler=text_chunker,
    name='CHUNK_TEXT',
    input_types=[StringType()],
    output_schema=schema,
    packages=['pandas', 'langchain'],
    is_permanent=True,
    replace=True,
    stage_location='LLM_DEMO.PODCASTS.UDF',
)

In [None]:
-- Create the chunked version of the table: one row per chunk that the
-- CHUNK_TEXT UDTF returns for each document's raw text.
CREATE OR REPLACE TABLE CHUNK_TEXT AS
SELECT
    relative_path,
    func.*
FROM
    raw_text AS raw,
    TABLE(chunk_text(raw_text)) AS func;

In [None]:
-- Convert your chunks to embeddings
CREATE OR REPLACE TABLE VECTOR_STORE AS
SELECT
    RELATIVE_PATH AS EPISODE_NAME,
    CHUNK AS CHUNK,
    SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2', chunk) AS chunk_embedding
FROM CHUNK_TEXT;

In [None]:
-- Vector distance allows us to find the chunk most similar to a question:
-- embed the question, order the chunks by L2 distance to it, keep the closest.
SELECT
    EPISODE_NAME,
    CHUNK
FROM LLM_DEMO.PODCASTS.VECTOR_STORE
ORDER BY
    VECTOR_L2_DISTANCE(
        SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2', 'What makes time perceived to be slower?'),
        CHUNK_EMBEDDING
    )
LIMIT 1;

In [None]:
-- Pass the chunk we need along with the prompt to get a better structured answer from the LLM.
-- The scalar subquery returns the single chunk closest to the question.
SELECT
    SNOWFLAKE.CORTEX.COMPLETE(
        'mistral-7b',
        CONCAT(
            'Answer the question based on the context. Be concise.',
            'Context: ',
            (
                SELECT chunk
                FROM LLM_DEMO.PODCASTS.VECTOR_STORE
                ORDER BY
                    VECTOR_L2_DISTANCE(
                        SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2', 'How should I optimise my caffeine intake?'),
                        chunk_embedding
                    )
                LIMIT 1
            ),
            'Question: ',
            'How should I optimise my caffeine intake?',
            'Answer: '
        )
    ) AS response;

In [None]:
import streamlit as st # Import python packages
from snowflake.snowpark.context import get_active_session
session = get_active_session() # Get the current credentials

st.title("Ask Your Data Anything :snowflake:")
st.write("""Built using end-to-end RAG in Snowflake with Cortex functions.""")

model = st.selectbox('Select your model:',('mistral-7b','llama2-70b-chat','gemma-7b','mixtral-8x7b'))

prompt = st.text_input("Enter prompt", placeholder="What makes time perceived to be slower?", label_visibility="collapsed")

# The user's prompt is free text, so it is passed as a bind parameter (the two
# `?` placeholders) instead of being pasted into the SQL string. A prompt that
# contains an apostrophe would otherwise break, or alter, the statement.
# `model` comes from the fixed selectbox options above, so it is interpolated.
quest_q = f'''
select snowflake.cortex.complete(
    '{model}', 
    concat( 
        'Answer the question based on the context. Be concise.','Context: ',
        (
            select chunk from LLM_DEMO.PODCASTS.VECTOR_STORE
            order by vector_l2_distance(
            snowflake.cortex.embed_text('e5-base-v2', 
            ?
            ), chunk_embedding
            ) limit 1
        ),
        'Question: ', 
        ?,
        'Answer: '
    )
) as response;
'''

if prompt:
    # One value per placeholder, in order: the text to embed, then the question.
    df_query = session.sql(quest_q, params=[prompt, prompt]).to_pandas()
    st.write(df_query['RESPONSE'][0])

import streamlit as st # Import python packages
from snowflake.snowpark.context import get_active_session
session = get_active_session() # Get the current credentials

st.title("Ask Your Data Anything :snowflake:")
st.write("""Built using end-to-end RAG in Snowflake with Cortex functions.""")

model = st.selectbox('Select your model:',('llama2-70b-chat','llama2-7b-chat'))

prompt = st.text_input("Enter prompt", placeholder="What makes time perceived to be slower?", label_visibility="collapsed")

# The user's prompt is free text, so it is passed as a bind parameter (the two
# `?` placeholders) instead of being pasted into the SQL string. A prompt that
# contains an apostrophe would otherwise break, or alter, the statement.
# `model` comes from the fixed selectbox options above, so it is interpolated.
# The context is the three closest SUM_CHUNK values, concatenated by listagg.
quest_q = f'''
select snowflake.cortex.complete(
    '{model}', 
    concat( 
        'Answer the question based on the context. Be concise.','Context: ',
        (
            with my_chunks(SUM_CHUNK) as (
            select SUM_CHUNK from LLM_DEMO.PODCASTS.TEST_VECT
            order by vector_l2_distance(
            snowflake.cortex.embed_text('e5-base-v2', 
            ?
            ), SUM_EMBEDDING
            ) limit 3)
            select listagg(SUM_CHUNK) as SUM_CHUNK from my_chunks
        ),
        'Question: ', 
        ?,
        'Answer: '
    )
) as response;
'''

if prompt:
    # One value per placeholder, in order: the text to embed, then the question.
    df_query = session.sql(quest_q, params=[prompt, prompt]).to_pandas()
    st.write(df_query['RESPONSE'][0])

-- Convert your chunks to embeddings and use Cortex to summarise each chunk
-- (so that more information can be passed through in a prompt).
-- Both the original chunk and its summary are embedded.
CREATE OR REPLACE TABLE VECTOR_STORE_EXTRA AS
SELECT
    RELATIVE_PATH AS EPISODE_NAME,
    CHUNK AS CHUNK,
    SNOWFLAKE.CORTEX.SUMMARIZE(chunk) AS SUMMARISED_CHUNK,
    SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2', chunk) AS chunk_embedding,
    SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2', summarised_chunk) AS sum_embedding
FROM CHUNK_TEXT;