In [1]:
# coding: utf-8
# Copyright (c) 2023, Oracle and/or its affiliates.  All rights reserved.
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.

##########################################################################
# chat_demo.py
# Supports Python 3
##########################################################################
# Info:
# Get texts from LLM model for given prompts using OCI Generative AI Service.
##########################################################################
# Application command line (no parameters needed)
# python chat_demo.py
##########################################################################
import os
import uuid

import jaydebeapi
import oci

In [3]:
def run_sql_query_in_adw(sql_query):
    """
    Run one SQL statement against the ADW instance over JDBC and return its rows.

    Connection settings default to the values written below. Each one can be
    overridden through an environment variable (ADW_JDBC_DRIVER_JAR,
    ADW_JDBC_URL, ADW_USERNAME, ADW_PASSWORD) so that secrets and
    machine-specific paths do not have to live in the notebook.

    :param sql_query: SQL text to execute; it is run with an empty bind list.
    :return: Whatever ``cursor.fetchall()`` returns, or ``None`` when
        connecting, executing or fetching raised. The error is printed,
        not re-raised.
    """
    # Path to your JDBC driver JAR file
    jdbc_driver_jar = os.environ.get("ADW_JDBC_DRIVER_JAR", "/Users/vidu/oracle_projects/ojdbc8.jar")

    # JDBC driver class for Oracle
    jdbc_driver_class = "oracle.jdbc.OracleDriver"

    # JDBC URL
    jdbc_url = os.environ.get(
        "ADW_JDBC_URL",
        "jdbc:oracle:thin:@(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=rutland.adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=gd1673479bcfa01_rutland_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=no)))",
    )

    # Database credentials
    # SECURITY: the literal fallbacks are kept only so existing runs keep
    # working. Prefer the environment variables, then rotate this password
    # and delete the literals.
    username = os.environ.get("ADW_USERNAME", "ADMIN")
    password = os.environ.get("ADW_PASSWORD", "RutlandAdminPassword")

    conn = None
    cursor = None
    try:
        # Establish connection
        conn = jaydebeapi.connect(
            jdbc_driver_class,
            jdbc_url,
            [username, password],
            jdbc_driver_jar
        )

        # Create a cursor
        cursor = conn.cursor()

        # Execute the query
        cursor.execute(sql_query, [])

        # Fetch the results; cleanup happens in the finally clause below
        return cursor.fetchall()

    except Exception as e:
        # Best-effort contract: report the failure and signal it with None.
        print(f"Error: {e}")
        return None

    finally:
        # Always release the cursor and the connection, cursor first.
        # A failing close is reported but must not replace the result.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_error:
                print(f"Error: {close_error}")

In [40]:
def create_genai_agent_session(
    config_profile="GENAI2",
    endpoint="https://agent-runtime.generativeai.us-chicago-1.oci.oraclecloud.com",
):
    """
    Build an OCI Generative AI Agent runtime client and open a new session on it.

    :param config_profile: Profile name to load from ``~/.oci/config``.
    :param endpoint: Service endpoint URL of the agent runtime.
    :return: A ``(client, session)`` tuple: the runtime client and the
        ``data`` attribute of the create-session response.
    """
    config = oci.config.from_file('~/.oci/config', config_profile)

    # Retries are disabled; the timeout is passed as a (10, 240) pair.
    generative_ai_agent_runtime_client = oci.generative_ai_agent_runtime.GenerativeAiAgentRuntimeClient(
        config=config,
        service_endpoint=endpoint,
        retry_strategy=oci.retry.NoneRetryStrategy(),
        timeout=(10, 240)
    )

    # Send the request to the service. Some parameters are not required;
    # see the API doc for more info. A random UUID is the display name.
    # The agent endpoint OCID stays fixed here because the chat and
    # delete helpers in this notebook address the same endpoint.
    create_session_response = generative_ai_agent_runtime_client.create_session(
        create_session_details=oci.generative_ai_agent_runtime.models.CreateSessionDetails(
            display_name=str(uuid.uuid4()),
        ),
        agent_endpoint_id="ocid1.genaiagentendpoint.oc1.us-chicago-1.amaaaaaarugtwcqaopxfxuvupzxrr26ku4pl47krfnbdcu4ukeu576dpdnbq"
    )

    # Get the data from the response
    return generative_ai_agent_runtime_client, create_session_response.data


In [42]:
def get_db_information():
    """
    Describe two ADW settings as a list of two sentences.

    The first sentence embeds the first column of the first row returned for
    the session's NLS_DATE_FORMAT. The second embeds the full result of the
    time zone query, restricted to names starting with 'America'.

    :return: A two-element list of strings.
    """
    # Date format first: only the single value from the first row is used.
    date_format_rows = run_sql_query_in_adw(
        "SELECT VALUE FROM NLS_SESSION_PARAMETERS WHERE PARAMETER = 'NLS_DATE_FORMAT'"
    )
    date_format_line = "The NLS_DATE_FORMAT of the ADW is '{}'".format(date_format_rows[0][0])

    # Time zones second: the whole result is formatted into the sentence.
    time_zone_rows = run_sql_query_in_adw(
        "SELECT TZNAME AS TIMEZONE FROM V$TIMEZONE_NAMES WHERE TZNAME LIKE 'America%'"
    )
    time_zone_line = "The available time zones of ADW are '{}'".format(time_zone_rows)

    return [date_format_line, time_zone_line]

In [50]:
def chat_with_agent(generative_ai_agent_runtime_client, session_id, message):
    """
    Send one non-streaming user message to the agent within a session.

    :param generative_ai_agent_runtime_client: Client whose ``chat`` method is called.
    :param session_id: Identifier of the session the message belongs to.
    :param message: Text passed as the user message.
    :return: The ``data`` attribute of the chat response.
    """
    endpoint_id = "ocid1.genaiagentendpoint.oc1.us-chicago-1.amaaaaaarugtwcqaopxfxuvupzxrr26ku4pl47krfnbdcu4ukeu576dpdnbq"

    # Build the request payload first, then issue the call.
    details = oci.generative_ai_agent_runtime.models.ChatDetails(
        user_message=message,
        should_stream=False,
        session_id=session_id,
    )
    response = generative_ai_agent_runtime_client.chat(
        agent_endpoint_id=endpoint_id,
        chat_details=details,
    )
    return response.data

In [124]:
def extract_tables(sql):
    """
    Extract table names from an SQL query.

    A table name is the token that directly follows a FROM or JOIN keyword,
    unless that token is SELECT (a subquery). Keywords are matched
    case-insensitively; table names are returned exactly as written, in
    order of appearance, duplicates included.

    Limitation: in a comma-separated FROM list only the first table is
    reported, because only the token right after the keyword is inspected.

    :param sql: The SQL query as a string.
    :return: A list of table names.
    """
    # Turn punctuation that can be glued to a keyword or a table name into
    # spaces. ',' and ';' are included so "from a.b;" yields "a.b".
    for separator in ('\n', '(', ')', '*', '=', ',', ';'):
        sql = sql.replace(separator, ' ')
    tokens = sql.split()

    tables = []
    # Walk adjacent token pairs: (keyword candidate, following token).
    for previous, current in zip(tokens, tokens[1:]):
        if previous.lower() in ('from', 'join') and current.lower() != 'select':
            tables.append(current)
    return tables

In [126]:
def get_table_schemas_of_query(query):
    """
    Look up column names and data types for every table referenced in a query.

    Table references come from ``extract_tables`` and must have the form
    ``OWNER.TABLE``. Each distinct reference is looked up once in
    ALL_TAB_COLUMNS, in order of first appearance.

    :param query: The SQL query as a string.
    :return: A dict mapping each table reference, exactly as written in the
        query, to the result of the ALL_TAB_COLUMNS lookup.
    :raises ValueError: If a reference is not of the form ``OWNER.TABLE``.
    """
    def _dictionary_literal(identifier):
        # Oracle stores unquoted identifiers in upper case in the data
        # dictionary; double-quoted identifiers keep their exact case.
        if len(identifier) >= 2 and identifier[0] == '"' and identifier[-1] == '"':
            name = identifier[1:-1]
        else:
            name = identifier.upper()
        # Double embedded single quotes so the name cannot end the SQL
        # string literal it is placed into.
        return name.replace("'", "''")

    schemas = {}
    # dict.fromkeys removes duplicates while keeping first-seen order.
    for adw_table in dict.fromkeys(extract_tables(query)):
        parts = adw_table.split('.')
        if len(parts) != 2:
            raise ValueError(
                "Expected a table reference of the form OWNER.TABLE, got {!r}".format(adw_table)
            )
        owner, table = (_dictionary_literal(part) for part in parts)
        schemas[adw_table] = run_sql_query_in_adw("SELECT COLUMN_NAME, DATA_TYPE FROM ALL_TAB_COLUMNS WHERE OWNER = '{}' AND TABLE_NAME = '{}'".format(owner, table))
    return schemas

In [138]:
def delete_session(generative_ai_agent_runtime_client, session_id):
    """
    Delete an agent session and return the headers of the service response.

    :param generative_ai_agent_runtime_client: Client whose ``delete_session`` method is called.
    :param session_id: Identifier of the session to delete.
    :return: The ``headers`` attribute of the delete response.
    """
    endpoint_id = "ocid1.genaiagentendpoint.oc1.us-chicago-1.amaaaaaarugtwcqaopxfxuvupzxrr26ku4pl47krfnbdcu4ukeu576dpdnbq"
    response = generative_ai_agent_runtime_client.delete_session(
        session_id=session_id,
        agent_endpoint_id=endpoint_id,
    )
    return response.headers

In [144]:
class SQLQueryFixSession:
    """
    Context manager around one GenAI agent session.

    The session is created when the object is constructed, not on entry.
    Leaving the ``with`` block deletes it. Exceptions raised inside the
    block are reported with a print and then propagate.
    """

    def __init__(self):
        # create_genai_agent_session returns a (client, session) pair.
        (self.generative_ai_agent_runtime_client,
         self.genai_agent_session) = create_genai_agent_session()

    def chat_with_agent(self, message):
        """Send ``message`` to the agent in this session and return the reply data."""
        client = self.generative_ai_agent_runtime_client
        session_id = self.genai_agent_session.id
        # This name resolves to the module-level function, not this method.
        return chat_with_agent(client, session_id, message)

    def __enter__(self):
        """Announce entry and return this object as the ``with`` target."""
        print("Create the session")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delete the session; report the in-flight exception, if any, without suppressing it."""
        print("Delete the session")
        delete_session(self.generative_ai_agent_runtime_client, self.genai_agent_session.id)
        if not exc_type:
            return
        print(f"An error occurred: {exc_value}")