In [2]:
!pip install streamlit

Defaulting to user installation because normal site-packages is not writeable
Collecting streamlit
  Downloading streamlit-1.36.0-py2.py3-none-any.whl (8.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.6/8.6 MB 5.2 MB/s eta 0:00:00
Collecting altair<6,>=4.0
  Downloading altair-5.3.0-py3-none-any.whl (857 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 857.8/857.8 KB 10.4 MB/s eta 0:00:00
Collecting toml<2,>=0.10.1
  Downloading toml-0.10.2-py2.py3-none-any.whl (16 kB)
Collecting gitpython!=3.1.19,<4,>=3.0.7
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.3/207.3 KB 27.3 MB/s eta 0:00:00
Collecting pydeck<1,>=0.8.0b4
  Downloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.9/6.9 MB 10.1 … (output truncated)

In [3]:
import os
import sqlite3
import uuid
from contextlib import closing
from warnings import filterwarnings

import numpy as np
import pandas as pd
import pygwalker as pyg
import streamlit as st
from pygwalker.api.streamlit import StreamlitRenderer

filterwarnings('ignore')

In [4]:
# Page chrome. set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="QueryLight Dev", layout="wide")
st.title('Chat with CSV')
# Widen the sidebar while it is expanded (500-700px) so the upload/model controls fit.
# The HTML is indented by fewer than 4 spaces on purpose: Markdown treats 4-space
# indented lines as a code block, which would print the CSS instead of applying it.
# The <style> element is explicitly closed so the injected HTML is well-formed.
st.markdown(
    """
   <style>
   [data-testid="stSidebar"][aria-expanded="true"]{
       max-width: 700px;
       min-width: 500px;
   }
   </style>
   """,
    unsafe_allow_html=True,
)

@st.cache_resource
def get_pyg_renderer(df):
    """Build a PyGWalker renderer for a query-result DataFrame, cached per distinct `df`.

    Uses `st.cache_resource` (the pattern in PyGWalker's Streamlit guide) rather than
    `st.cache_data`. `cache_data` serialises and copies its return value, which is
    meant for data, not for a live UI object such as `StreamlitRenderer`.

    The chart spec is read from `$CODEGEN_PROJECT_PATH/app/gw_config.json`. This raises
    KeyError if that environment variable is not set.
    """
    # If you want to use feature of saving chart config, set `spec_io_mode="rw"`
    spec_path = os.path.join(os.environ['CODEGEN_PROJECT_PATH'], 'app', 'gw_config.json')
    return StreamlitRenderer(df, spec=spec_path)

def getTableDescriptionSQLiteDB(output_fmt, db_path='uploaded_csvs.db'):
    """Describe every user table stored in a SQLite database file.

    Args:
        output_fmt: selects the value type of the returned dict.
            'df'  -> {table_name: DataFrame with columns ['name', 'type', 'comment']}.
                     It is parsed from the table's CREATE TABLE text, one row per line
                     between the first and last line. Each line is split at its last
                     space into (name, type). 'comment' is all-NaN.
            'ddl' -> {table_name: CREATE TABLE text}.
            Any other value returns an empty dict.
        db_path: SQLite file to read. The default matches the previous hard-coded
            'sqlite:///uploaded_csvs.db' (relative to the working directory).

    Only `sqlite_master` rows of type 'table' with non-NULL SQL are used. Indexes
    (auto-indexes have NULL sql, which the line parser cannot split) and SQLite's
    internal `sqlite_*` tables are skipped.

    NOTE(review): the sidebar writes tables through utils.getSQLiteDB(user_id), which
    may not point at `db_path` — confirm both sides use the same file.
    NOTE(review): the 'df' parser assumes one column definition per line (the layout
    pandas `to_sql` emits). Table-level constraints or multi-word types are split naively.
    """
    # stdlib sqlite3 is the DBAPI connection pandas.read_sql_query supports natively.
    # closing() guarantees the file handle is released (sqlite3's own context manager
    # only commits/rolls back, it does not close).
    with closing(sqlite3.connect(db_path)) as conn:
        sqlite_master = pd.read_sql_query(
            "SELECT name, sql FROM sqlite_master "
            "WHERE type = 'table' AND sql IS NOT NULL AND substr(name, 1, 7) != 'sqlite_'",
            conn,
        )

    table_desc_dict = {}
    for table_name, ddl in zip(sqlite_master['name'], sqlite_master['sql']):
        if output_fmt == 'df':
            # Drop the "CREATE TABLE ... (" header and the closing ")" line.
            rows = [line.strip().strip(',').rsplit(' ', maxsplit=1) for line in ddl.split('\n')[1:-1]]
            desc = pd.DataFrame(columns=['name', 'type'], data=rows)
            desc['comment'] = np.nan
            table_desc_dict[table_name] = desc
        elif output_fmt == 'ddl':
            table_desc_dict[table_name] = ddl
    return table_desc_dict

def getModelResult(schema, question, model_name, selected_table, table_columns):
    """Generate a SQL query for `question` against a single table using an LLM.

    Steps:
      1. Pre-processing: prune `schema` with DDL_PRUNE to the columns most relevant
         to `question`. On any exception the full `schema` is used instead.
      2. Fill `prompt_template` and send it to GPT-4 (when model_name == 'GPT-4') or
         to an Ollama model named `model_name`.
      3. Post-processing: reformat the raw query with queryPostprocessing. On any
         exception the raw model output is returned unchanged.

    Args:
        schema: DDL of `selected_table` (passed to DDL_PRUNE as `table_ddl`).
        question: natural-language business question.
        model_name: 'GPT-4', or the name of an Ollama model.
        selected_table: table being queried. It also names the embedding file
            f'{const.EMBEDDING_PATH}/{selected_table}.pkl'.
        table_columns: columns of `selected_table`, handed to post-processing.

    Returns:
        (processed_query, prompt): the post-processed (or raw, on failure) query and
        the exact prompt sent to the model.

    NOTE(review): DDL_PRUNE, const, prompt_template, GPT_4, OLLAMA, username_azure,
    password_file and queryPostprocessing are not defined in this notebook and must be
    provided elsewhere. The stage-2 block calls utils.getModelResultMulti, so this
    function appears unused here.
    """
    embedding_model_name = 'mixedbread-ai/mxbai-embed-large-v1'  # TODO: move this to configs

    try:
        print('Running pre-processing...')
        ddl_pruner = DDL_PRUNE(question=question,
                               table_name=selected_table,
                               table_ddl=schema,
                               emb_path=f'{const.EMBEDDING_PATH}/{selected_table}.pkl',
                               top_k_columns=const.TOP_K_LIMIT,  # TODO: move this to configs
                               embedding_model_name=embedding_model_name,
                               save_embs=True)
        pruned_schema = ddl_pruner.prune(const.PRUNE_LIMIT)
    except Exception as e:
        # Best effort: an unpruned schema still yields a usable prompt.
        print("Preprocessing failed!")
        print(e)
        pruned_schema = schema
    prompt = prompt_template.format(question=question, db_schema=pruned_schema)

    print(f'Querying {model_name}...')
    if model_name == 'GPT-4':
        gpt4 = GPT_4(azure_username=username_azure, azure_password=password_file)
        query = gpt4.run(prompt)
    else:
        ollama = OLLAMA(model_name=model_name)
        query = ollama.run(prompt)

    print(f"Done! Received query: {query}")
    print('Running post-processing...')
    try:
        qp = queryPostprocessing(query, {'table_name': selected_table, 'columns': table_columns}, embedding_model_name)
        processed_query = qp.formatQuerySQLglot()
    except Exception as e:
        # Best effort: fall back to the model's raw query.
        print("Post-processing failed!")
        print(e)
        processed_query = query
    print(f"Done! Processed query: {processed_query}")
    return processed_query, prompt

# Session-state defaults. Streamlit re-executes the whole script on every interaction,
# and later code reads df_dict / stage / user_id unconditionally. Each key is only set
# if absent, so values initialised elsewhere (or on an earlier rerun) are kept.
# NOTE(review): stage 0 matches set_state(0) on upload; user_id is presumably a
# per-session id for utils.getSQLiteDB (uuid is imported for it) — confirm.
if 'df_dict' not in st.session_state:
    st.session_state.df_dict = {}
if 'stage' not in st.session_state:
    st.session_state.stage = 0
if 'user_id' not in st.session_state:
    st.session_state.user_id = str(uuid.uuid4())

with st.sidebar:
    st.title('Data Sources')
    source = st.radio('Pick Source', ["CSV", "Snowflake"], index=None)
    if source == "CSV":
        with st.popover("Upload CSV"):
            uploaded_files = st.file_uploader("Choose a CSV file", type=['csv'], accept_multiple_files=True, on_change=set_state, args = [0])
            with st.spinner('Processing...'):
                for uploaded_file in uploaded_files:
                    # Table name = file name without its extension ("sales.csv" -> "sales").
                    table_name = os.path.splitext(uploaded_file.name)[0]
                    uploaded_file.seek(0)  # read from the start even if the buffer was read before
                    st.session_state.df_dict[table_name] = pd.read_csv(uploaded_file)
                db = utils.getSQLiteDB(st.session_state.user_id)
                for table_name, df in st.session_state.df_dict.items():
                    # to_sql's default if_exists='fail' raises for tables that already exist
                    # (e.g. written on an earlier rerun), so those are skipped. A re-uploaded
                    # CSV with the same name therefore does NOT replace its table.
                    try:
                        df.to_sql(table_name, db, index=False)
                    except Exception as e:
                        print(f'Failed to load CSV {table_name!r} into SQLite DB ({e}). Moving on to the next one!')

    elif source == "Snowflake":
        st.session_state.df_dict = {}
        with st.popover("Fetch from Snowflake"):
            with st.spinner('Left as an exercise to the reader...'):
                pass

    model_name = st.radio(label = 'Model', index=0, options = ['sqlc-7b-2-F16','GPT-4'])
    st.button(label = 'Load tables', on_click=set_state, args=[1])

# Stage >= 1 ("Load tables" clicked): show each uploaded table's columns and ask for
# the business question.
if st.session_state.stage >= 1:
    # Defined for every source so the stage-2 block can read it without a NameError
    # when the source is not CSV (it then simply has no tables).
    table_desc_ddls = {}
    if source == "CSV":
        table_desc_dfs = getTableDescriptionSQLiteDB(output_fmt='df')
        for table_name, table_desc in table_desc_dfs.items():
            st.subheader(f'Uploaded CSV Table: {table_name}')
            st.dataframe(table_desc)
            st.write('\n\n')
        table_desc_ddls = getTableDescriptionSQLiteDB(output_fmt='ddl')

    question = st.text_input('Business Query', placeholder='Enter business requirement to be converted to query', on_change=set_state, args=[2])
    st.button('Get SQL Query', on_click=set_state, args=[2])

# Stage 2: generate SQL for the question once, cache it in session state and advance
# to stage 3, so later reruns reuse the cached query instead of calling the model again.
if st.session_state.stage == 2:
    if not question.strip():
        # Don't spend a model call on an empty question; stay in stage 2.
        st.warning('Enter a business query to generate SQL.')
    elif not table_desc_ddls:
        st.warning('No tables loaded. Upload a CSV and click "Load tables" first.')
    else:
        table_names = list(table_desc_ddls.keys())
        table_ddls = list(table_desc_ddls.values())  # CREATE TABLE text, aligned with table_names
        with st.spinner('AI code generation in progress'):
            sql_query, prompt = utils.getModelResultMulti(table_ddls, question, model_name, table_names)
        st.session_state.sql_query = sql_query
        st.session_state.prompt = prompt
        st.session_state.stage = 3

# Stage >= 3: show the prompt and let the user edit the generated query before running it.
# NOTE(review): sqlparse is used here but not imported in this notebook — confirm it is
# imported elsewhere.
if st.session_state.stage >= 3:
    sql_query = (st.session_state.sql_query or '').strip()
    prompt = st.session_state.prompt
    # Drop one trailing ';'. endswith() avoids the IndexError that sql_query[-1] raised
    # on an empty result, and strip() above catches "...;\n".
    if sql_query.endswith(';'):
        sql_query = sql_query[:-1]

    with st.expander(label='Prompt'):
        st.code(prompt, language='sql')

    with st.container():
        modified_query = st.text_area(label = f'Generated Query {model_name}', value = sqlparse.format(sql_query, reindent=True))

    st.button("Execute query", on_click=set_state, args=[4])

# Stage >= 4 ("Execute query" clicked): run the (possibly user-edited) query and show
# the result. The button calls set_state(4). Assuming set_state(i) sets stage = i
# (TODO confirm), the gate must be 4, because no visible code ever sets stage 5.
if st.session_state.stage >= 4:
    try:
        with st.spinner('Fetching data from database...'):
            print('Running SQL on SQLite')
            query_data = utils.getSQLiteDBQueryResult(query = modified_query, user_id = st.session_state.user_id)

        st.dataframe(query_data)

        n_cols = query_data.shape[1]
        # The PyGWalker explorer is only useful when there is more than one column to chart.
        if n_cols > 1:
            print('Loading pygwalker UI')
            renderer = get_pyg_renderer(query_data)
            renderer.explorer()
        else:
            print('Not loading pygwalker UI for single column result')
        print('SQL executed successfully!')

        print('Run completed. \n\n')
    except Exception as e:
        st.error(f'Error occured while processing query...{e}')
        print(f'SQL failed to execute. Error: {e}')

2024-06-24 13:25:04.105 
  command:

    streamlit run /home/cloud-user/.local/lib/python3.10/site-packages/ipykernel_launcher.py [ARGUMENTS]


DeltaGenerator()