### Mock DB

In [None]:
import sqlite3
import threading
import queue

def db_worker(db_queue, db_path):
    """Serve SQL requests from ``db_queue`` on a single SQLite connection.

    The connection is opened inside this function, so it is created and used
    by whichever thread runs the worker.

    Each queue item is a ``(query, args, result_queue)`` tuple. The statement
    is executed and exactly one reply is put on ``result_queue``:

    * ``(True, rows)`` -- the ``fetchall()`` list when the statement produces
      a result set (``SELECT``, ``WITH ... SELECT``, ``PRAGMA`` reads, ...);
    * ``(True, rowcount)`` -- the cursor's ``rowcount`` for any other
      statement;
    * ``(False, message)`` -- ``str(exception)`` when execution fails.

    An item whose ``query`` is ``None`` stops the loop.

    Args:
        db_queue: ``queue.Queue`` carrying the request tuples.
        db_path: Database path handed to ``sqlite3.connect``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        while True:
            query, args, result_queue = db_queue.get()
            if query is None:
                break
            try:
                cursor.execute(query, args)
                # cursor.description is set only when the statement returns
                # rows; this is more reliable than sniffing for a leading
                # "SELECT", which misses CTEs and PRAGMA reads.
                if cursor.description is not None:
                    result = cursor.fetchall()
                else:
                    result = cursor.rowcount
                # A commit with no open transaction is a no-op.
                conn.commit()
            except Exception as e:
                # Drop any half-applied statement so the next request starts
                # from a clean transaction state.
                conn.rollback()
                result_queue.put((False, str(e)))
            else:
                result_queue.put((True, result))
    finally:
        # Close the connection even if the loop dies unexpectedly.
        conn.close()

def create_tables(db_queue):
    """Queue the DDL for the mock ``observation`` and ``plane`` tables.

    Both ``CREATE TABLE IF NOT EXISTS`` statements are put on ``db_queue``
    (with empty argument tuples and one shared reply queue) before either
    reply is read.

    Args:
        db_queue: Queue consumed by the database worker.

    Returns:
        A 2-tuple holding the first and second reply taken from the reply
        queue, in that order.
    """
    observation_ddl = '''
        CREATE TABLE IF NOT EXISTS observation (
          observationURI CHAR ,
          sequenceNumber INT,
          metaReadGroups CHAR,
          proposal_keywords CHAR,
          target_standard INT,
          target_redshift DOUBLE,
          target_moving INT,
          target_keywords CHAR,
          targetPosition_equinox DOUBLE,
          targetPosition_coordinates_cval1 DOUBLE,
          targetPosition_coordinates_cval2 DOUBLE,
          telescope_geoLocationX DOUBLE,
          telescope_geoLocationY DOUBLE,
          telescope_geoLocationZ DOUBLE,
          telescope_keywords CHAR,
          instrument_keywords CHAR,
          environment_seeing DOUBLE,
          environment_humidity DOUBLE,
          environment_elevation DOUBLE,
          environment_tau DOUBLE,
          environment_wavelengthTau DOUBLE,
          environment_ambientTemp DOUBLE,
          environment_photometric INT,
          members CHAR,
          typeCode CHAR,
          metaProducer CHAR,
          metaChecksum CHAR,
          accMetaChecksum CHAR,
          obsID CHAR PRIMARY KEY,
          collection CHAR,
          observationID CHAR,
          algorithm_name CHAR,
          type CHAR,
          intent CHAR,
          metaRelease CHAR,
          proposal_id CHAR,
          proposal_pi CHAR,
          proposal_project CHAR,
          proposal_title CHAR,
          target_name CHAR,
          target_targetID CHAR,
          target_type CHAR,
          targetPosition_coordsys CHAR,
          telescope_name CHAR,
          requirements_flag CHAR,
          instrument_name CHAR,
          lastModified CHAR,
          maxLastModified CHAR
      )

    '''

    plane_ddl = '''
        CREATE TABLE IF NOT EXISTS plane (
          publisherID CHAR,
          planeURI CHAR PRIMARY KEY,
          creatorID CHAR,
          obsID CHAR REFERENCES observation(obsID),
          metaReadGroups CHAR,
          dataReadGroups CHAR,
          calibrationLevel INT,
          provenance_keywords CHAR,
          provenance_inputs CHAR,
          metrics_sourceNumberDensity DOUBLE,
          metrics_background DOUBLE,
          metrics_backgroundStddev DOUBLE,
          metrics_fluxDensityLimit DOUBLE,
          metrics_magLimit DOUBLE,
          position_bounds CHAR,
          position_bounds_samples DOUBLE,
          position_bounds_size DOUBLE,
          position_resolution DOUBLE,
          position_sampleSize DOUBLE,
          position_dimension_naxis1 LONG,
          position_dimension_naxis2 LONG,
          position_timeDependent INT,
          energy_bounds_samples DOUBLE,
          energy_bounds_lower DOUBLE,
          energy_bounds_upper DOUBLE,
          energy_bounds_width DOUBLE,
          energy_dimension LONG,
          energy_resolvingPower DOUBLE,
          energy_sampleSize DOUBLE,
          energy_freqWidth DOUBLE,
          energy_freqSampleSize DOUBLE,
          energy_restwav DOUBLE,
          time_bounds_samples DOUBLE,
          time_bounds_lower DOUBLE,
          time_bounds_upper DOUBLE,
          time_bounds_width DOUBLE,
          time_dimension LONG,
          time_resolution DOUBLE,
          time_sampleSize DOUBLE,
          time_exposure DOUBLE,
          polarization_dimension LONG,
          custom_bounds_samples DOUBLE,
          custom_bounds_lower DOUBLE,
          custom_bounds_upper DOUBLE,
          custom_bounds_width DOUBLE,
          custom_dimension LONG,
          metaProducer CHAR,
          metaChecksum CHAR,
          accMetaChecksum CHAR,
          planeID CHAR,
          productID CHAR,
          metaRelease CHAR,
          dataRelease CHAR,
          dataProductType CHAR,
          provenance_name CHAR,
          provenance_version CHAR,
          provenance_reference CHAR,
          provenance_producer CHAR,
          provenance_project CHAR,
          provenance_runID CHAR,
          provenance_lastExecuted CHAR,
          observable_ucd CHAR,
          quality_flag CHAR,
          position_resolutionBounds DOUBLE,
          energy_bounds DOUBLE,
          energy_resolvingPowerBounds DOUBLE,
          energy_emBand CHAR,
          energy_energyBands CHAR,
          energy_bandpassName CHAR,
          energy_transition_species CHAR,
          energy_transition_transition CHAR,
          time_bounds DOUBLE,
          time_resolutionBounds DOUBLE,
          polarization_states CHAR,
          custom_ctype CHAR,
          custom_bounds DOUBLE,
          lastModified CHAR,
          maxLastModified CHAR
      )
    '''

    replies = queue.Queue()
    # Enqueue observation first: plane declares a reference to it.
    for ddl in (observation_ddl, plane_ddl):
        db_queue.put((ddl, (), replies))

    first_reply = replies.get()
    second_reply = replies.get()
    return first_reply, second_reply

def run_query(db_queue, query, args=()):
    """Send one statement to the database worker and wait for its reply.

    Args:
        db_queue: Queue consumed by the database worker.
        query: SQL text to execute.
        args: Parameters forwarded with the statement (default: none).

    Returns:
        The payload of a successful reply.

    Raises:
        Exception: Carrying the reply payload when the reply is flagged as
            unsuccessful.
    """
    reply_box = queue.Queue()
    request = (query, args, reply_box)
    db_queue.put(request)

    # Blocks until the worker has answered this specific request.
    ok, payload = reply_box.get()
    if ok:
        return payload
    raise Exception(payload)

# Request queue and on-disk database file used by the cells below.
db_queue = queue.Queue()
db_path = 'obsplane_db.sqlite'

# Start the database worker thread. It is a daemon thread because this
# notebook never queues the (None, ...) shutdown sentinel: a non-daemon worker
# would stay blocked in db_queue.get() and keep the interpreter from exiting.
db_thread = threading.Thread(target=db_worker, args=(db_queue, db_path), daemon=True)
db_thread.start()

# Create tables
create_tables(db_queue)

((True, -1), (True, -1))

In [None]:
# Smoke test: read up to three rows from the observation table through the worker.
result = run_query(db_queue, "SELECT * FROM observation LIMIT 3;")
print(result)

[]


### Installs


In [None]:
 %%capture
!pip install -q langchain
!pip install -q openai
!pip install langchain_community
!pip install langchain-experimental
!pip install -U langchain-openai
!pip install tabulate
!pip install langchainhub
# Install the requests library if needed
!pip install requests
!pip install pandas
!pip install -qU langchain-openai
!pip install -q langchain
!pip install langchain_community
!pip install langchain-experimental
!pip install langchain-core
!pip install langchain-text-splitter
!pip install langchain-chroma
!pip install pypdf
!pip install langchainhub
!pip install rapidocr-onnxruntime

!pip install  openai langchain sentence_transformers chromadb unstructured -q
!pip install -U langchain-community
!pip install unstructured


!pip install -q langchain
!pip install -q openai
!pip install langchain_community
!pip install langchain-experimental
!pip install -U langchain-openai
!pip install tabulate

### Clone Resources


In [None]:
# Fetch the NRCChatDB repository; later cells read resource files from '/content/NRCChatDB/NRCResources/...'.
!git clone https://github.com/ShaylinThadani/NRCChatDB

Cloning into 'NRCChatDB'...
remote: Enumerating objects: 25, done.[K
remote: Counting objects: 100% (25/25), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 25 (delta 2), reused 6 (delta 1), pack-reused 0 (from 0)[K
Receiving objects: 100% (25/25), 1.51 MiB | 4.76 MiB/s, done.
Resolving deltas: 100% (2/2), done.


### Run SQL Tool

In [None]:
from langchain.tools import tool
@tool
def checkSQLMock(sql_code: str) -> str:
    """Returns a string that states if the SQL query is valid"""
    # NOTE: LangChain's @tool uses the docstring as the tool description, so
    # its wording is left exactly as it was.
    try:
        print("Running SQL query...")
        # Execute against the mock database; a failing statement raises here.
        run_query(db_queue, sql_code)
        outcome = "Successful! " + sql_code
    except Exception as error:
        # Failures are reported through the return value, not re-raised.
        outcome = f"SQL query failed with error: {str(error)}"
    return outcome

### Utils for Semantic Search and RAG

In [None]:
import nltk
# Download the NLTK 'punkt' tokenizer data.
# NOTE(review): presumably required by the document loaders used below - confirm.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
from langchain.document_loaders import DirectoryLoader

def load_docs(directory):
    """Load the documents under ``directory`` using ``DirectoryLoader``.

    Args:
        directory: Path handed to ``DirectoryLoader``.

    Returns:
        Whatever ``DirectoryLoader(directory).load()`` returns.
    """
    return DirectoryLoader(directory).load()


In [None]:
from langchain_core.documents.base import Document

def split_docs(documents, separators=(",",)):
    """Split each document's text into stripped fragments.

    The text of every document is split on each separator in turn, so a
    fragment never contains any of the separators. With the default single
    comma this is a plain ``page_content.split(",")``.

    Args:
        documents: Iterable of objects exposing ``page_content`` (str) and
            ``metadata`` (dict).
        separators: Iterable of separator strings. When it is empty the text
            is not split and each document yields one stripped fragment.

    Returns:
        A list of ``Document`` objects, one per fragment, in input order.
        Each fragment gets its own shallow copy of the source metadata so
        that editing one fragment's metadata does not affect its siblings.
    """
    split_documents = []
    for document in documents:
        parts = [document.page_content]
        for separator in separators:
            parts = [piece for part in parts for piece in part.split(separator)]
        for part in parts:
            split_documents.append(
                Document(page_content=part.strip(), metadata=dict(document.metadata))
            )
    return split_documents

### Valid Columns tool

In [None]:
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
import os
# Directory of valid-column resource files inside the cloned repository.
valid_columns_directory = '/content/NRCChatDB/NRCResources/validColumns'

# Load the resource files and split them (on commas, the split_docs default) into fragments.
valid_columns_documents = load_docs(valid_columns_directory)
valid_columns_docs = split_docs(valid_columns_documents)
# NOTE(review): the key is blank here (presumably redacted). This assignment
# overwrites any OPENAI_API_KEY already present in the environment.
os.environ["OPENAI_API_KEY"] = ""
# Index the fragments in a Chroma store using OpenAI embeddings; checkColumns queries it.
valid_columns_vector_db = Chroma.from_documents(valid_columns_docs, OpenAIEmbeddings())
# checkColumns classifies each comma-separated column name against the
# hard-coded column sets of the mock observation and plane tables.
# - Names found in both tables are reported as belonging to both.
# - Blank names (e.g. from a trailing comma) are skipped.
# - Unknown names get suggestions from a similarity search over
#   valid_columns_vector_db.
# The docstring is kept on one line because it serves as the tool description.
@tool
def checkColumns(columns: str) -> str:
    """Accepts a string of columns separated by comas and checks which table the columns belong to. It returns a string representation of a dictionary where the key is the column and the value is either {column does not exist, In Observation table, In Plane table, In Observation table and In Plane table} """
    print("checked columns")
    columns = columns.split(",")
    columns = [column.strip() for column in columns]
    obs_columns_set = {'telescope_name', 'target_type', 'type', 'metaReadGroups', 'target_moving', 'collection', 'environment_ambientTemp', 'target_name', 'target_keywords', 'metaChecksum', 'environment_wavelengthTau', 'telescope_geoLocationZ', 'instrument_keywords', 'obsID', 'proposal_title', 'environment_tau', 'targetPosition_coordinates_cval2', 'typeCode', 'metaProducer', 'observationID', 'proposal_project', 'telescope_geoLocationX', 'proposal_pi', 'targetPosition_coordsys', 'accMetaChecksum', 'instrument_name', 'lastModified', 'requirements_flag', 'target_targetID', 'sequenceNumber', 'targetPosition_coordinates_cval1', 'environment_humidity', 'intent', 'algorithm_name', 'maxLastModified', 'telescope_keywords', 'environment_elevation', 'target_standard', 'telescope_geoLocationY', 'members', 'targetPosition_equinox', 'environment_seeing', 'metaRelease', 'proposal_id', 'target_redshift', 'environment_photometric', 'observationURI', 'proposal_keywords'}
    plane_columns_set= {'metaReadGroups', 'time_bounds_samples', 'energy_freqWidth', 'position_bounds_samples', 'dataProductType', 'position_dimension_naxis2', 'provenance_runID', 'lastModified', 'energy_resolvingPower', 'provenance_name', 'position_sampleSize', 'energy_bounds_width', 'position_dimension_naxis1', 'energy_sampleSize', 'metrics_magLimit', 'time_resolutionBounds', 'provenance_producer', 'energy_resolvingPowerBounds', 'energy_bounds_upper', 'provenance_inputs', 'provenance_keywords', 'time_bounds_upper', 'polarization_states', 'energy_restwav', 'position_timeDependent', 'energy_dimension', 'metaProducer', 'quality_flag', 'custom_bounds_upper', 'energy_transition_transition', 'energy_bounds', 'time_bounds_width', 'provenance_version', 'planeID', 'custom_bounds_samples', 'planeURI', 'creatorID', 'metaRelease', 'observable_ucd', 'custom_bounds', 'provenance_reference', 'energy_bounds_lower', 'publisherID', 'metrics_background', 'polarization_dimension', 'custom_dimension', 'metaChecksum', 'time_resolution', 'provenance_project', 'obsID', 'metrics_backgroundStddev', 'dataRelease', 'accMetaChecksum', 'time_sampleSize', 'position_resolutionBounds', 'time_bounds', 'time_dimension', 'calibrationLevel', 'energy_transition_species', 'custom_ctype', 'metrics_sourceNumberDensity', 'energy_bounds_samples', 'energy_emBand', 'position_resolution', 'metrics_fluxDensityLimit', 'position_bounds_size', 'dataReadGroups', 'custom_bounds_width', 'custom_bounds_lower', 'productID', 'energy_freqSampleSize', 'time_exposure', 'energy_bandpassName', 'position_bounds', 'maxLastModified', 'provenance_lastExecuted', 'time_bounds_lower', 'energy_energyBands'}

    result = dict()
    for column in columns:
        if not column:
            # Stray commas produce empty names; there is nothing to look up.
            continue
        in_observation = column in obs_columns_set
        in_plane = column in plane_columns_set
        if in_observation and in_plane:
            # Shared columns (obsID, lastModified, ...) exist in both tables.
            result[column] = "In Observation table and In Plane table"
        elif in_observation:
            result[column] = "In Observation table"
        elif in_plane:
            result[column] = "In Plane table"
        else:
            # Unknown name: suggest the closest valid columns instead.
            query = "What is the most similar to " + column + "?"
            matching_docs = valid_columns_vector_db.similarity_search(query)
            page_contents = [doc.page_content for doc in matching_docs[:10]]
            result[column] = "column does not exist, try one of these" + str(page_contents)
    print(str(result))
    return str(result)

### Column Mapping Tool

In [None]:
from langchain.document_loaders import DirectoryLoader

# Directory of column-mapping resource files inside the cloned repository.
map_columns_directory = '/content/NRCChatDB/NRCResources/columnmappings'

# Load the mapping files and split them (on commas, the split_docs default) into fragments.
map_columns_documents = load_docs(map_columns_directory)
map_columns_docs = split_docs(map_columns_documents)


In [None]:
import bs4
import requests
import os
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Load, chunk and index the contents of the blog.
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAI


# NOTE(review): the key is blank here (presumably redacted). This assignment
# overwrites any OPENAI_API_KEY already present in the environment.
os.environ["OPENAI_API_KEY"] = ""
# Index the column-mapping fragments in a Chroma store using OpenAI embeddings.
map_columns_vector_db = Chroma.from_documents(documents=map_columns_docs, embedding=OpenAIEmbeddings())

# Retriever over the column-mapping fragments and the RAG prompt pulled from the LangChain hub.
retriever = map_columns_vector_db.as_retriever()
prompt = hub.pull("rlm/rag-prompt")
def format_docs(docs):
    """Join the page_content of the given documents, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)

# NOTE(review): api_key is blank here (presumably redacted).
map_columns_llm= OpenAI(api_key="")


# Chain: retrieved fragments become "context", the input passes through
# unchanged as "question", then prompt -> LLM -> plain string.
map_columns_rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | map_columns_llm
    | StrOutputParser()
)


In [None]:
from langchain.tools import tool

@tool
def alternateColumn(column: str) -> str:
    """Checks to see if a column needs to be changes to an alternate column name"""
    # Ask the column-mapping RAG chain whether this name appears before "maps to".
    question = "Check to see if this word is before 'maps to', and if it is what is it mapped to: " + column
    answer = map_columns_rag_chain.invoke(question)
    # Trace output: the marker line followed by the chain's answer.
    print("used alternate column", answer, sep="\n")
    return answer

### Value Synonyms Tool

In [None]:
from langchain.document_loaders import DirectoryLoader

# Directory of value-synonym resource files inside the cloned repository.
value_synonyms_directory = '/content/NRCChatDB/NRCResources/valueSynonymMappings'

# Load the synonym files and split them (on commas, the split_docs default) into fragments.
value_synonyms_documents = load_docs(value_synonyms_directory)
value_synonyms_docs = split_docs(value_synonyms_documents)


In [None]:
import bs4
import requests
import os
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Load, chunk and index the contents of the blog.
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAI


# NOTE(review): the key is blank here (presumably redacted). This assignment
# overwrites any OPENAI_API_KEY already present in the environment.
os.environ["OPENAI_API_KEY"] = ""
# Index the value-synonym fragments in a Chroma store using OpenAI embeddings.
value_synonyms_vector_db = Chroma.from_documents(documents=value_synonyms_docs, embedding=OpenAIEmbeddings())

# Retriever over the value-synonym fragments and the RAG prompt pulled from the LangChain hub.
# These rebind the module-level names `retriever`, `prompt` and `format_docs`.
retriever = value_synonyms_vector_db.as_retriever()
prompt = hub.pull("rlm/rag-prompt")
def format_docs(docs):
    """Join the page_content of the given documents, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)

# NOTE(review): api_key is blank here (presumably redacted).
value_synonyms_llm= OpenAI(api_key="")


# Chain: retrieved fragments become "context", the input passes through
# unchanged as "question", then prompt -> LLM -> plain string.
value_synonyms_rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | value_synonyms_llm
    | StrOutputParser()
)


In [None]:
from langchain.tools import tool
@tool
def alternateValue(value: str) -> str:
    """Checks to see if a value needs to be changed to an alternate value name"""
    # Ask the value-synonym RAG chain whether this value maps to another one.
    question = "Check to see if this word maps to another: " + value
    answer = value_synonyms_rag_chain.invoke(question)
    # Trace output: the chain's answer, then the value that was passed in.
    print("used alternate value tool", answer, "value passed in", value, sep="\n")
    return answer

### Schema

In [None]:
import pandas as pd
import requests
import xml.etree.ElementTree as ET
from langchain.tools import tool
pd.set_option('display.max_rows', None)
def get_table_schema(table_name: str) -> pd.DataFrame:
    """Fetch a table's column schema from the CADC argus service.

    Issues the ADQL query ``SELECT * FROM <table_name> LIMIT 0`` and reads the
    column metadata from the ``FIELD`` elements of the VOTable (v1.3
    namespace) response.

    Args:
        table_name: Table to describe, e.g. ``"caom2.Observation"``. It is
            interpolated into the request URL as given.

    Returns:
        A DataFrame with one row per ``FIELD`` and the columns ``Column``
        (name), ``Type`` (datatype, ``None`` when the attribute is absent) and
        ``Description`` (text of the ``DESCRIPTION`` child, ``None`` when the
        element is absent or empty).

    Raises:
        requests.HTTPError: If the service answers with an error status.
        requests.Timeout: If the service does not answer within 60 seconds.
        xml.etree.ElementTree.ParseError: If the response is not valid XML.
    """
    votable_ns = "{http://www.ivoa.net/xml/VOTable/v1.3}"
    headers = {'Accept': 'application/x-votable+xml'}
    response = requests.get(
        f"https://ws.cadc-ccda.hia-iha.nrc-cnrc.gc.ca/argus/sync?LANG=ADQL&QUERY=SELECT%20*%20FROM%20{table_name}%20LIMIT%200",
        headers=headers,
        timeout=60,  # never hang the notebook on an unresponsive service
    )
    # Fail loudly on an HTTP error instead of trying to parse an error page.
    response.raise_for_status()

    # Parse the raw bytes so the encoding declared in the XML prolog is honoured.
    root = ET.fromstring(response.content)

    columns = []
    types = []
    descriptions = []

    for field in root.findall(f".//{votable_ns}FIELD"):
        description = field.find(f"{votable_ns}DESCRIPTION")
        columns.append(field.attrib['name'])
        types.append(field.attrib.get('datatype'))
        # Not every FIELD carries a DESCRIPTION child.
        descriptions.append(description.text if description is not None else None)

    schema_df = pd.DataFrame({
        'Column': columns,
        'Type': types,
        'Description': descriptions
    })

    return schema_df

def format_schemas(*schema_dfs) -> str:
    """Render one or more schema DataFrames as a text block.

    Args:
        *schema_dfs: DataFrames with ``Table``, ``Column``, ``Type`` and
            ``Description`` columns. The table name is taken from the first
            row's ``Table`` value.

    Returns:
        A newline-terminated string: a ``Table Schema:`` header, then for each
        frame a ``Table:`` line followed by one line per column, and finally
        the fixed join-condition line.
    """
    lines = ["Table Schema:"]
    for frame in schema_dfs:
        lines.append(f"Table: {frame['Table'].iloc[0]}")
        lines.extend(
            f"  Column: {row['Column']} - Type: {row['Type']} - Description: {row['Description']}"
            for _, row in frame.iterrows()
        )
    lines.append("Join Condition: caom2.Plane.obsID = caom2.Observation.obsID")
    return "\n".join(lines) + "\n"

# Fetch the column schemas of the two remote tables.
obs_schema = get_table_schema("caom2.Observation")
plane_schema = get_table_schema("caom2.Plane")
# Label every row with the mock SQLite table name ('observation' / 'plane')
# rather than the remote caom2.* name; format_schemas prints this label.
obs_schema['Table'] = 'observation' # Changed table names
plane_schema['Table'] = 'plane'
# Show both schemas in the notebook output.
display(obs_schema)
display(plane_schema)
# Format both schemas into one string; SQLAgent.run passes it to the SQL generator.
combined_schema_str = format_schemas(obs_schema, plane_schema)

In [None]:
def get_react_prompt_template():
    """Build the ReAct prompt used by the input-preprocessing agent.

    Returns:
        A ``PromptTemplate`` created from the text below, with the
        placeholders ``{tools}``, ``{tool_names}``, ``{input}`` and
        ``{agent_scratchpad}``.
    """
    # The wording is the prompt the LLM sees, so it is kept verbatim.
    react_template = """Answer the following questions as accurately as possible. Your primary task is to enhance the user's input by ensuring it is complete, accurate, and consistent with the database schema. You will need to augment and improve the input to add the correct table names to the prompt, change or swap out any incorrect column names and change any values that have an alternate value. To achieve this, follow these instructions carefully:

1. **alternateColumn**: Always start by verifying if any column names in the user's input need to be updated or corrected based on the database schema. Use this tool to find the correct column names if needed.
2. **checkColumns**: Next, ensure that all mentioned columns exist in the specified table(s). Confirm which columns belong to which table(s) and that the columns in the user's input are valid.
3. **alternateValue**: Finally, check if any values (like specific names or IDs) in the user's input need to be corrected. Use this tool to verify and update values as necessary.

**Important:** You must use these tools in the order provided for every query.


    You have access to the following tools:

    {tools}

    Use the following format:

    Question: the input question you must answer
    Thought: you should always think about what to do
    Action: the action to take, should be one of [{tool_names}]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can repeat N times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question

    Begin!

    Question: {input}
    Thought:{agent_scratchpad}
    """
    return PromptTemplate.from_template(react_template)

### Custom Agent

In [None]:
from langchain.agents import initialize_agent # try react
from langchain_openai import OpenAI
from langchain import PromptTemplate
from tabulate import tabulate
from IPython.display import display
from langchain_openai import OpenAI
from langchain.agents import AgentExecutor, create_react_agent
from langchain import hub

# Prompt that turns a natural-language question plus the table schema into a
# single SQL query; {schema} and {question} are filled in by generate_sql.
template = """
You are an expert SQL query generator. Given a natural language question and a table schema, generate an appropriate SQL query.
Only return the SQL query and nothing else.
{schema}
Question: {question}
SQL Query:
"""

# NOTE: this rebinds the module-level name `prompt`, which earlier cells
# assigned to the hub RAG prompt.
prompt = PromptTemplate(input_variables=["schema", "question"], template=template)

# LLM used by generate_sql and SQLAgent.regenerate_sql.
# NOTE(review): no api_key is passed here; presumably the client falls back to
# the OPENAI_API_KEY environment variable - confirm.
llm = OpenAI()

# Define a function that uses the LLM to generate SQL
def generate_sql(question: str, schema: str) -> str:
    """Ask the module-level LLM for a SQL query answering ``question``.

    Args:
        question: Natural-language question.
        schema: Schema text substituted into the prompt.

    Returns:
        The LLM response with surrounding whitespace removed.
    """
    filled_prompt = prompt.format(schema=schema, question=question)
    return llm.invoke(filled_prompt).strip()

class SQLAgent:
    """Turn a natural-language question into SQL and run it, with retries.

    A question is first rewritten by a ReAct preprocessing agent (column
    mapping, column validation, value synonyms). SQL is then generated and
    executed through ``sql_tool``. A failed attempt feeds its error message
    and query back into the LLM to produce a corrected query.

    Attributes:
        sql_tool: Object whose ``invoke(sql)`` executes a query.
        sql_generator: Callable ``(question, schema) -> str`` producing SQL.
        max_retries: Maximum number of generate-and-execute attempts.
        preprocessAgent: Executor returned by ``createPreprocessAgent``.
    """

    # Prefix of the string checkSQLMock returns for a failed query. That tool
    # reports failure through its return value rather than by raising, so
    # run() has to recognise the message to trigger a retry.
    FAILURE_PREFIX = "SQL query failed"

    def __init__(self, sql_tool, sql_generator):
        """Store the collaborators and build the preprocessing agent."""
        self.sql_tool = sql_tool
        self.sql_generator = sql_generator
        self.max_retries = 3
        self.preprocessAgent = self.createPreprocessAgent()

    def createPreprocessAgent(self):
        """Build the ReAct agent executor that rewrites the user's question.

        Returns:
            An ``AgentExecutor`` wired with the alternateColumn, checkColumns
            and alternateValue tools.
        """
        # Choose the LLM to use
        llm2 = OpenAI(api_key="")

        # Get the react prompt template
        prompt_template = get_react_prompt_template()

        # set the tools
        tools = [alternateColumn, checkColumns, alternateValue]

        # Construct the ReAct agent
        agent = create_react_agent(llm2, tools, prompt_template)

        # Create an agent executor by passing in the agent and tools
        return AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)

    def preprocess(self, input: str) -> dict:
        """Run the preprocessing agent on ``input``.

        Returns:
            The executor's result; ``run`` reads its ``"output"`` entry.
        """
        improved_input = self.preprocessAgent.invoke({"input": input})
        return improved_input

    def regenerate_sql(self, question: str, schema: str, error_messages: list, previous_queries: list) -> str:
        """Ask the LLM for a corrected query given earlier failures.

        Args:
            question: The (preprocessed) natural-language question.
            schema: Schema text substituted into the prompt.
            error_messages: Error strings from the failed attempts.
            previous_queries: SQL strings tried so far.

        Returns:
            The LLM response stripped of surrounding whitespace and of leading
            and trailing semicolons. May be empty if the LLM returns nothing.
        """
        # Create a new prompt that includes the error messages and previous queries to guide the LLM in fixing the query
        regenerate_prompt = """
        You are an expert SQL query generator. Given a natural language question, a table schema, previous error messages, and previous SQL queries, generate a corrected SQL query.
        Only return the SQL query and nothing else. Do not include a semicolon at the end of the query.

        Schema:
        {schema}

        Question:
        {question}

        Error messages:
        {error_messages}

        Previous SQL queries:
        {previous_queries}

        Corrected SQL Query:
        """
        error_messages_str = '\n'.join(error_messages)
        previous_queries_str = '\n'.join(previous_queries)
        prompt = PromptTemplate(input_variables=["schema", "question", "error_messages", "previous_queries"], template=regenerate_prompt)
        response = llm.invoke(prompt.format(schema=schema, question=question, error_messages=error_messages_str, previous_queries=previous_queries_str))
        # strip(";") already removes any trailing semicolon. Indexing the last
        # character afterwards would only add an IndexError on an empty reply.
        return response.strip().strip(";")

    def run(self, question: str) -> str:
        """Answer ``question`` by generating and executing SQL.

        The first attempt uses ``sql_generator``. Every later attempt uses
        ``regenerate_sql`` with the errors and queries collected so far.

        Args:
            question: Natural-language question from the user.

        Returns:
            Whatever ``sql_tool.invoke`` returns for the first query that
            succeeds.

        Raises:
            Exception: If no attempt succeeds within ``max_retries`` tries.
        """
        schema = combined_schema_str
        error_messages = []
        previous_queries = []
        print("orignal question")
        print(question)
        question = self.preprocess(question)["output"]
        print("augmented question")
        print(question)
        for attempt in range(self.max_retries):
            try:
                if error_messages:
                    # A previous attempt failed: give the LLM that feedback.
                    sql_query = self.regenerate_sql(question, schema, error_messages, previous_queries)
                else:
                    # Generate the SQL query from the question and schema
                    sql_query = self.sql_generator(question, schema)
                print(f"Attempt {attempt+1}: Generated query: {sql_query}")
                if not sql_query:
                    raise ValueError("The LLM returned an empty SQL query.")
                previous_queries.append(sql_query)
                # Run the generated SQL query using the tool
                result = self.sql_tool.invoke(sql_query)
                if isinstance(result, str) and result.startswith(self.FAILURE_PREFIX):
                    # The tool signalled failure in-band; treat it as an error.
                    raise RuntimeError(result)
                return result
            except Exception as e:
                error_message = str(e)
                print(f"Attempt {attempt+1}: Error encountered: {error_message}")
                error_messages.append(error_message)
        raise Exception("Failed to generate a correct SQL query after " + str(self.max_retries) + " attempts.")
# Initialize the agent: checkSQLMock runs queries against the mock DB, generate_sql drafts them.
agent = SQLAgent(checkSQLMock, generate_sql)


### Examples

In [None]:
agent.run("I need the first 5 entries in the observation table where the collection column value is Hubble")


In [None]:
# english question
# did not use the check columns tool and made up a table
agent.run("Show the names of instruments and filters for observations with a calibration level of '2' and type 'ALIGN'")

In [None]:
# synonyms
agent.run("Find the first 20 rows of the instrument_name column where the instrument_name is B-Three, Band IV, Band V, B-Six, or Band VII in the observation table.")

In [None]:
# filter column needs to be mapped
agent.run("I need the first 5 entries in the observation table where the filter column value is C2")

In [None]:
# filter column needs to be mapped
agent.run("I need the first 5 entries in the observation table where the filter column value is C2 and instrument_name is B-Three, Band IV, Band V, B-Six, or Band VII")