In [None]:
from google.cloud import aiplatform
import vertexai
from vertexai.preview.language_models import CodeGenerationModel, CodeChatModel
import google.cloud.sql.connector
from google.cloud.sql.connector import Connector
import sqlalchemy
from importlib.metadata import version
from datetime import datetime

# Report the installed version of every distribution this notebook relies on.
# Each entry pairs the label to print with the distribution name to look up.
for version_label, distribution in (
    ("google-cloud-aiplatform (vertexai) version:", 'google-cloud-aiplatform'),
    ("google-cloud-sql-connector version:", 'cloud-sql-python-connector'),
    ("sqlalchemy version:", 'sqlalchemy'),
    ("asyncpg version:", 'asyncpg'),
):
    print(version_label, version(distribution))

In [None]:
# Utils
from langchain.schema import HumanMessage, SystemMessage
from langchain.llms import VertexAI
from langchain_google_vertexai import VertexAI
from langchain.embeddings import VertexAIEmbeddings
# from langchain.chat_models import ChatVertexAI
from langchain_google_vertexai import ChatVertexAI
from google.cloud import aiplatform
import time
from typing import List


In [None]:
# Empty placeholders for the Google Cloud project and region.
# Both names are assigned again in a later cell, right before vertexai.init().
PROJECT_ID=""
REGION=""

In [None]:
import ssl
# Build the client-side TLS context handed to the database driver.
# The protocol is passed explicitly because calling ssl.SSLContext() without
# one is deprecated since Python 3.10.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# PROTOCOL_TLS_CLIENT enables hostname checking by default. It is switched off
# here so the effective verification stays what it was with the bare
# SSLContext(): the certificate chain is validated, the host name is not.
# NOTE(review): turn this back on if the server certificate carries the host
# name used in the connection URL.
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_REQUIRED
# Trust anchor for the server certificate, then the client certificate/key
# pair presented to the server.
ssl_context.load_verify_locations("./ssl/root_cert.pem")
ssl_context.load_cert_chain("./ssl/client_cert.pem", "./ssl/client_key.pem")

# Extra keyword arguments forwarded by SQLAlchemy to the driver's connect().
connect_args = {"ssl_context": ssl_context}

In [None]:

# Create the SQLAlchemy engine. "<db details>" is a placeholder for the real
# connection URL; connect_args carries the TLS context built in the previous cell.
engine = sqlalchemy.create_engine(
    "<db details>",connect_args=connect_args
)

In [None]:
from langchain_community.utilities import SQLDatabase

# Wrap the SQLAlchemy engine in LangChain's SQLDatabase helper.
db = SQLDatabase(engine)

# Remember the dialect reported by the wrapper and show it together with the
# tables the wrapper exposes.
dialect = db.dialect
print(dialect)
print(db.get_usable_table_names())

# Sample lookup that the next cell runs against the database.
query = "select * from schema.table1 where logon_id='vramya'"
print(query)

In [None]:
# Run the sample query through the SQLDatabase wrapper and display what it returns.
db.run(query)

In [None]:
# SQL that produces one descriptive sentence per table ("The table is called
# <schema>.<table> and it has following columns <columns>.") from
# INFORMATION_SCHEMA.columns. The '<schema name>' and
# (<comma seperated table name>) parts are placeholders that must be replaced
# with real values before the statement is run.
schema_value_generator = "SELECT concat('The table is called ',table_schema,'.',table_name, ' and it has following columns ',table_column,'.') as table_column FROM ( SELECT table_schema, table_name, array_agg(column_name) as table_column FROM INFORMATION_SCHEMA.columns WHERE table_schema = '<schema name>' and table_name in (<comma seperated table name>) GROUP BY table_schema, table_name ORDER BY table_name) col GROUP BY table_schema,table_name, table_column;"


In [None]:
# Execute the schema-description query; the result is cleaned up with string
# replacements in a later cell before it is placed into the prompt.
tabledetails = db.run(schema_value_generator)

In [None]:
# Inspect the value returned by db.run() before any clean-up.
tabledetails

In [None]:
# Remove the list/tuple punctuation and quotes from the db.run() result so that
# only the descriptive sentences are left. The fragments are removed in this
# exact order: the multi-character row separators have to go before the lone
# "(" / "[" / "]" characters are stripped.
for unwanted_fragment in ("'", ",),", ",)]", "(", "[", "]"):
    tabledetails = tabledetails.replace(unwanted_fragment, "")

In [None]:
# Inspect the cleaned-up table description that will be fed to the prompt.
tabledetails

In [None]:
from langchain_experimental.sql import SQLDatabaseSequentialChain
from vertexai.preview.language_models import CodeGenerationModel, CodeChatModel

from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from sqlalchemy.ext.declarative import declarative_base
from langchain.agents.agent_types import AgentType
from langchain import LLMChain,PromptTemplate
from langchain.agents import create_sql_agent 
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from langchain.llms import VertexAI
from IPython.display import display, Markdown
from langchain import SQLDatabase
from tabulate import tabulate
from datetime import date
from pathlib import Path
import pandas as pd
import time

In [None]:
import langchain

In [None]:
import sqlalchemy

In [None]:
import vertexai

# Project and region passed to the Vertex AI SDK below. The single-space values
# are placeholders (the `# @param` markers are presumably Colab form fields);
# replace them with real values before running this cell.
PROJECT_ID = " "  # @param {type:"string"}
REGION = " "  # @param {type:"string"}

# Initialize Vertex AI SDK
vertexai.init(project=PROJECT_ID, location=REGION)

In [None]:
# Utils
from langchain.schema import HumanMessage, SystemMessage
from langchain.llms import VertexAI
from langchain_google_vertexai import VertexAI
from langchain.embeddings import VertexAIEmbeddings
# from langchain.chat_models import ChatVertexAI
from langchain_google_vertexai import ChatVertexAI
from google.cloud import aiplatform
import time
from typing import List



In [None]:
# LLM used as the default model of create_sql_chain: code-bison@001 with a
# 1000-token output cap and temperature 0.2.
# NOTE(review): confirm this model version is still served in the target project.
llm = VertexAI(
    model_name="code-bison@001", max_output_tokens=1000, temperature=0.2
)

In [None]:
# Prompt template for the text-to-SQL helpers in this notebook.
# Placeholders filled in at call time: {top_k}, {today_date}, {table_info},
# {question}. The wording targets PostgreSQL throughout (dialect name and
# double-quoted identifiers) so that it agrees with the self-check section
# further down in the prompt.
CUSTOM_SQL_PROMPT = """
You are a PostgreSQL expert. Given an input question, first create a syntactically
correct PostgreSQL query to run, then look at the results of the query and return
the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain,
query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can
order the results to return the most informative data in the database.

Never query for all columns from a table. You must query only the columns that are
needed to answer the question. Wrap each column name in double quotes (") to denote
them as delimited identifiers.

Pay attention to use only the column names you can see in the tables below. Be careful
to not query for columns that do not exist. Also, pay attention to which column
is in which table.

Name all columns in the returned data appropriately. If a column does not have a
matching name in the schema, create an appropriate name reflecting its content.

Write an initial draft of the query. Then double check the PostgreSQL query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Use the following format:

Question: "Question here"

SQLQuery: "SQL Query to run"

SQLResult: "Result of the SQLQuery"

Answer: "Final answer here"


Today's date is {today_date}. When querying between dates, add the dates in quotes
('')

If someone asks for a specific month, use the range between the current month's
start date and the current month's end date.

If someone asks for a specific year, use the range between the first month of the
current year and the current month's end date.


Remember to always use natural language when writing your final answer.

Only use the following tables:

{table_info}

Question: {question}

"""

# Option 1: answer the question with LangChain's SQLDatabaseSequentialChain

In [None]:
from datetime import datetime

def create_sql_chain(question: str, table_info: str = tabledetails, top_k:int=100, llm: VertexAI = llm, db=db):
    """Answer ``question`` by running a LangChain SQL chain over ``db``.

    The function prints ``table_info``, builds a ``SQLDatabaseSequentialChain``
    from ``llm`` and ``db``, renders ``CUSTOM_SQL_PROMPT`` and passes the
    rendered text to the chain as its input.

    Args:
        question: Natural-language question to answer.
        table_info: Table description inserted into the prompt. The default is
            whatever ``tabledetails`` held when this function was defined.
        top_k: Row limit mentioned in the prompt.
        llm: Language model used by the chain.
        db: Database object used by the chain.

    Returns:
        A ``(response, sql_query)`` tuple: the chain's ``"result"`` entry and
        the second element of its ``"intermediate_steps"`` list (expected to
        hold the generated SQL).
    """
    print(table_info)

    # verbose=True makes the chain log its steps; the intermediate steps are
    # requested because one of them is returned to the caller.
    sequential_chain = SQLDatabaseSequentialChain.from_llm(
        llm, db, verbose=True, return_intermediate_steps=True
    )

    prompt_template = PromptTemplate(
        template=CUSTOM_SQL_PROMPT,
        input_variables=["question", "table_info", "today_date", "top_k"],
    )

    # The prompt carries the current date formatted as MM/DD/YYYY.
    rendered_prompt = prompt_template.format(
        question=question,
        table_info=table_info,
        today_date=datetime.now().strftime("%m/%d/%Y"),
        top_k=top_k,
    )
    chain_output = sequential_chain(rendered_prompt)

    generated_sql = chain_output["intermediate_steps"][1]
    final_answer = chain_output["result"]
    return final_answer, generated_sql

In [None]:
# Exercise Option 1 through create_sql_chain, which is defined in the cell above.
# get_sql_from_code_gen is only defined in a later cell, so calling it here
# raises NameError when the notebook is run top to bottom on a fresh kernel.
response, query = create_sql_chain('Find the program run today')
query

# Option 2: call the Vertex AI code generation model directly

In [None]:

def get_sql_from_code_gen(project: str, region: str, prompt: str, question: str, temperature: float = 0.2, max_output_tokens: int = 1024, model: str = 'code-bison@001', table_info: "str | None" = None, top_k: int = 100):
    """Ask a Vertex AI code generation model to answer ``question``.

    ``prompt`` is rendered as a template with the question, the table
    description, today's date (MM/DD/YYYY) and ``top_k``; the rendered text is
    printed and then sent to the model as the ``prefix``.

    Args:
        project: Google Cloud project passed to ``vertexai.init``.
        region: Location passed to ``vertexai.init``.
        prompt: Template text containing the ``{question}``, ``{table_info}``,
            ``{today_date}`` and ``{top_k}`` placeholders.
        question: Natural-language question to answer.
        temperature: Sampling temperature forwarded to ``predict``.
        max_output_tokens: Output token cap forwarded to ``predict``.
        model: Name given to ``CodeGenerationModel.from_pretrained``.
        table_info: Table description inserted into the prompt. When omitted,
            the notebook-level ``tabledetails`` value is read at call time.
        top_k: Row limit mentioned in the prompt (100 by default).

    Returns:
        The ``text`` attribute of the model response.
    """
    if table_info is None:
        # Fall back to the notebook-level schema description, as before.
        table_info = tabledetails

    prompt_template = PromptTemplate(
        template=prompt,
        input_variables=["question", "table_info", "today_date", "top_k"],
    )
    # Render once so that the text printed below is exactly what is sent.
    rendered_prompt = prompt_template.format(
        question=question,
        table_info=table_info,
        today_date=datetime.now().strftime("%m/%d/%Y"),
        top_k=top_k,
    )

    vertexai.init(project=project, location=region)
    code_gen_model = CodeGenerationModel.from_pretrained(model)

    print(f"""Prompt to send to code gen API: \n{rendered_prompt}""")

    response = code_gen_model.predict(
        prefix=rendered_prompt,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
    )
    return response.text

In [None]:
# Exercise Option 2: send the question straight to the code generation model
# and display the text it returns.
query=get_sql_from_code_gen(PROJECT_ID, REGION, CUSTOM_SQL_PROMPT, 'Find the program run today')
query