In [1]:
from utils import read_config, OracleAgent
import re
import os
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate
)


import google.generativeai as genai
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the local connection/credential config once for the whole notebook and
# expose the Gemini key via GOOGLE_API_KEY (the environment variable
# ChatGoogleGenerativeAI falls back to when no api key is passed explicitly).
configs = read_config(".env/info.json")
os.environ.update({"GOOGLE_API_KEY": configs["gkey"]})

In [3]:
# Build the database agent from the connection info stored under
# configs['DW_conn_info']. `configs` was already loaded from .env/info.json in
# the previous cell, so the file is not read a second time here.
# The BI database connection (configs['BIDB_conn_info']) is not used in this notebook.
DWDB = configs['DW_conn_info']
dw_agent = OracleAgent(DWDB)



In [4]:
# Fetch every view owned by the YFYDW schema together with its defining SQL.
# ORDER BY makes the row order deterministic: Oracle returns rows in no
# guaranteed order without it, and later cells pick a view by position
# (view_info.iloc[0]).
query = """
    SELECT view_name, text FROM ALL_Views
    where owner = 'YFYDW'
    ORDER BY view_name
"""

view_info = dw_agent.read_table(query=query)

In [5]:
# Display the view names and their defining SQL returned by the query above.
view_info

Unnamed: 0,view_name,text
0,WACES_CARBON_EMS01_DF_V,"SELECT\n ORG.ORG_NAME,\n CHECKTYPE_CODE,..."
1,WACES_CARBON_EMS02_MF_V,"select \n org.org_name,\n c.check_date p..."


In [6]:
# Pattern for "everything from the FROM keyword to the end of the SQL text".
#   re.DOTALL: lets '.' match newline characters, so the match runs across lines.
#   re.IGNORECASE: makes the search case-insensitive ("FROM", "from", "From", ...).
#   \b: stops the pattern from matching inside identifiers such as VALID_FROM.
# NOTE(review): this takes the FIRST FROM in the text; a subquery in the SELECT
# list would be picked up instead of the outer FROM clause.
FROM_CLAUSE_PATTERN = re.compile(r'\bFROM\s+(.*)', re.DOTALL | re.IGNORECASE)


def extract_data_source(sql):
    """Return the part of a view's SQL starting at its first FROM keyword.

    Args:
        sql: The view's defining SQL text (the ``text`` column of view_info).

    Returns:
        The matched text from ``FROM`` (inclusive) to the end of ``sql``, or
        ``None`` when ``sql`` is not a string or contains no FROM keyword.
    """
    # A missing/NULL view text would otherwise make re.search raise TypeError.
    if not isinstance(sql, str):
        return None
    match = FROM_CLAUSE_PATTERN.search(sql)
    # Views without a FROM clause yield None instead of crashing the whole apply.
    return match.group(0) if match else None


view_info['data_source'] = view_info['text'].apply(extract_data_source)

In [7]:
# Display the frame again, now including the extracted data_source column.
view_info

Unnamed: 0,view_name,text,data_source
0,WACES_CARBON_EMS01_DF_V,"SELECT\n ORG.ORG_NAME,\n CHECKTYPE_CODE,...","FROM WACES_CARBON_EMS_DF F,\n DIM_ORG O..."
1,WACES_CARBON_EMS02_MF_V,"select \n org.org_name,\n c.check_date p...","from WBIPD_PRODUCTION_INDS_MF p,\n (sele..."


In [8]:
# print() (rather than rich display) so the SQL's embedded newlines render as line breaks.
print(view_info["text"].iloc[0])

SELECT
    ORG.ORG_NAME,
    CHECKTYPE_CODE,
    CHECKTYPE_NAME,
    CHECK_DATE,
    PRODUCT_CODE,
    PRODUCT_NAME,
    EMISSION_CATEGORY,
    CO2_QTY
  FROM WACES_CARBON_EMS_DF F,
       DIM_ORG  ORG
 WHERE 1=1
   AND ORG.RPT_USED = 'ESG碳排分析'
   AND F.ORG_CODE = ORG.ORG_CODE


In [9]:
# Prompt asking the LLM how a view (table_name) relates to the SQL it is built
# from (datasource). The system message carries only the instructions; the two
# inputs are interpolated once, in the human message, so each value reaches the
# model once rather than twice (the LLM cell below also sets
# convert_system_message_to_human=True, which folds the system text into the
# human turn).
system_template = (
    "I will provide a table_name and the SQL datasource it is built from. "
    "Tell me the relationship between the table and its datasource."
)

human_template = (
    "table_name: {table_name}\n"
    "\n"
    "datasource: {datasource}"
)

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template(human_template),
]

# Input variables: table_name, datasource.
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)

In [10]:
# Gemini chat model used to answer the prompt.
#   convert_system_message_to_human=True: merge the system prompt into the human
#     turn. NOTE(review): presumably needed because gemini-pro did not accept
#     system messages — confirm.
#   safety_settings: disable blocking (BLOCK_NONE) for the three harm categories
#     listed; categories not listed are not overridden here.
llm = ChatGoogleGenerativeAI(
    safety_settings={
        category: HarmBlockThreshold.BLOCK_NONE
        for category in (
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        )
    },
    convert_system_message_to_human=True,
    model="gemini-pro",
)

In [12]:
# Ask the model how the first view (row 0 of view_info) relates to its data source.
input_data = dict(
    table_name=view_info.iloc[0].view_name,
    datasource=view_info.iloc[0].data_source,
)
chain = CHAT_PROMPT | llm
llm_response = chain.invoke(input_data)



In [15]:
# Display the text of the model's answer.
llm_response.content

'The relationship between the table_name and the datasource is that the table_name, WACES_CARBON_EMS01_DF_V, is derived from the data in the datasource. The datasource is a SQL query that selects data from the tables WACES_CARBON_EMS_DF and DIM_ORG, and then joins the two tables on the ORG_CODE column. The resulting table is then used to create the WAGES_CARBON_EMS01_DF_V table.'