# Udemy LangChain SQLite Setup Notebook

This repository hosts the Jupyter notebook `Udemy_LangChain_SQLite_Setup.ipynb`, a guide for developers, analysts, architects, and ML engineers who want to use LangChain with SQLite databases for data analysis. The notebook demonstrates how to query a SQL database in natural language, using financial data as the working example.

## Features

- **LangChain and SQLite Integration**: Demonstrates how to use LangChain with SQLite for querying and analyzing financial data using natural language.
- **Financial Data Analysis**: Focuses on financial data, providing insights into economic indicators, yield curve prices, production data, and business cycles.
- **Practical Examples**: Includes examples of SQL queries for data insertion, schema creation, and advanced data analysis techniques.
- **Comprehensive Setup**: Guides through the setup process for LangChain, SQLite, and necessary Python packages.

## Prerequisites

- Python 3.x installed on your system.
- Jupyter Notebook or Jupyter Lab for running notebook files.
- Basic knowledge of SQL, Python programming, and financial datasets.

## Installation

1. **Clone the Repository**: Clone this repository to your local machine to get access to the notebook and related files.

    ```bash
    git clone https://github.com/anjoGF/Medium_LangChain_Demo.git
    ```

2. **Install Dependencies**: Install the required Python packages using pip. The notebook includes a section with pip commands to install packages such as `pandas`, `openai`, `langchain`, and others. `sqlite3` is part of the Python standard library and does not need to be installed.

    ```bash
    pip install pandas openai langchain langchain-experimental pandas-datareader plotly redis faiss-cpu SQLAlchemy python-dotenv
    ```

3. **Set Up SQLite Database**: Follow the instructions within the notebook to set up your SQLite database, create tables, and insert sample data for analysis.

4. **Configure OpenAI API Key**: Ensure you have an OpenAI API key and set it up as shown in the notebook to use LangChain's features.

## Usage

- Open the `Udemy_LangChain_SQLite_Setup.ipynb` notebook in Jupyter Notebook/Lab.
- Follow the step-by-step instructions to explore LangChain's integration with SQLite.
- Use the provided SQL queries and modify them as needed to analyze your financial datasets.

## Contributing

Contributions to improve the notebook or add more examples are welcome. Please feel free to fork the repository, make your changes, and submit a pull request.

## License

This project is licensed under the MIT License. See the LICENSE file in the repository for more information.

## Contact

For any queries or further assistance, please open an issue in the GitHub repository.

# **Setting Up the Environment**

Before diving into the data analysis, it's essential to set up our environment by installing all necessary Python packages. This includes `pandas` for data manipulation, `openai` and `langchain` for leveraging LangChain's capabilities, and other packages used throughout this notebook. The `sqlite3` module for interacting with SQLite databases ships with Python, so it is not in the install list.

Run the following command to install the required packages:


In [None]:
%%bash
pip install pandas
pip install scikit-learn
pip install pandas-datareader
pip install plotly redis
pip install langchain langchain-experimental
pip install openai
pip install tiktoken
pip install faiss-cpu
pip install SQLAlchemy
pip install python-dotenv

# **Introduction to LangChain and SQLite Integration**

This section introduces the concept of integrating LangChain with SQLite databases for financial data analysis. LangChain allows for querying and analyzing data using natural language, making it a powerful tool for AI architects and ML engineers. This notebook demonstrates how to set up LangChain and SQLite for financial data analysis, including the installation of necessary Python packages and the configuration of the OpenAI API key.


In [None]:
OPENAI_API_KEY="your-api-key"
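
Hard-coding the key in a cell makes it easy to leak when the notebook is shared. A minimal sketch of an alternative, assuming the key has been exported as an environment variable named `OPENAI_API_KEY`; the helper name `get_openai_api_key` is hypothetical and not part of the notebook:

```python
import os

def get_openai_api_key(env_var: str = "OPENAI_API_KEY") -> str:
    """Return the API key from the environment, or raise a clear error."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(
            f"Set the {env_var} environment variable before running the notebook."
        )
    return key
```

In the notebook this could replace the literal assignment with `OPENAI_API_KEY = get_openai_api_key()`. If the key lives in a `.env` file, calling `load_dotenv()` from the `python-dotenv` package beforehand should populate the environment.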


### **Creating the SQLite Database**

In this section, we will create an SQLite database and define the schema for our financial data analysis. This involves creating tables for economic indicators, yield curve prices, production data, and business cycles. We will also insert some sample data into these tables to be used in our analysis. Follow the code blocks below to create your SQLite database and tables.


In [None]:
import sqlite3

def create_tables(db_name):
    with sqlite3.connect(db_name) as conn:
        cursor = conn.cursor()

        # Create economic indicators table with consistent date column name
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS economic_indicators (
                Date TEXT PRIMARY KEY,
                UNRATE REAL,
                PAYEMS REAL,
                ICSA REAL,
                CIVPART REAL,
                INDPRO REAL
            )
        ''')

        # Create yield_curve_prices table with consistent date column name
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS yield_curve_prices (
                Date TEXT PRIMARY KEY,
                DGS1MO REAL,
                DGS3MO REAL,
                DGS6MO REAL,
                DGS1 REAL,
                DGS2 REAL,
                DGS3 REAL,
                DGS5 REAL,
                DGS7 REAL,
                DGS10 REAL,
                DGS20 REAL,
                DGS30 REAL
            )
        ''')

        # Create production_data table with consistent date column name
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS production_data (
                Date TEXT PRIMARY KEY,
                SAUNGDPMOMBD REAL,
                ARENGDPMOMBD REAL,
                IRNNGDPMOMBD REAL,
                SAUNXGO REAL,
                QATNGDPMOMBD REAL,
                KAZNGDPMOMBD REAL,
                IRQNXGO REAL,
                IRNNXGO REAL,
                KWTNGDPMOMBD REAL,
                IPN213111S REAL,
                PCU213111213111 REAL,
                DPCCRV1Q225SBEA REAL
            )
        ''')

        # Create business_cycles table with an auto-increment ID as the primary key
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS business_cycles (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                Peak_Month TEXT,
                Trough_Month TEXT,
                Start_Date TEXT,
                End_Date TEXT,
                Phase TEXT
            )
        ''')

        # Optionally, add indexes on date columns if they will be used in joins or queries often
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_econ_date ON economic_indicators (Date)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_yield_date ON yield_curve_prices (Date)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_prod_date ON production_data (Date)')

        print("Tables created successfully.")

if __name__ == "__main__":
    create_tables("financial_data.db")

Tables created successfully.
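
To confirm what was actually created, the schema can be read back from SQLite itself. A minimal sketch, using an in-memory database and a trimmed-down `business_cycles` table so it runs on its own; the helper name `describe_tables` is an assumption for illustration. Passing `"financial_data.db"` to `sqlite3.connect` instead would inspect the real database.

```python
import sqlite3

def describe_tables(conn: sqlite3.Connection) -> dict:
    """Map each user table name to its list of column names."""
    tables = [row[0] for row in conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    )]
    # PRAGMA table_info returns one row per column; index 1 is the column name.
    return {
        name: [col[1] for col in conn.execute(f'PRAGMA table_info("{name}")')]
        for name in tables
    }

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE business_cycles (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        Start_Date TEXT,
        End_Date TEXT,
        Phase TEXT
    )
""")
schema = describe_tables(conn)
print(schema)
conn.close()
```

The `NOT LIKE 'sqlite_%'` filter hides SQLite's internal bookkeeping tables, such as the `sqlite_sequence` table that `AUTOINCREMENT` creates.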


### **Inserting Data into the Database**

After setting up our database schema, the next step is to populate the tables with data. The business-cycle rows are inserted directly with SQL `INSERT` statements, while the economic indicator, yield curve, and production series are downloaded from FRED with `pandas-datareader` and written to the database with `DataFrame.to_sql`. This data serves as the foundation for our financial data analysis using LangChain and natural language queries.


In [None]:
def insert_business_cycle_data(db_name):
    business_cycles = [
        {"peak": "1999-03-01", "trough": "2001-03-01", "start": "1999-03-01 00:00:00", "end": "2001-03-01 00:00:00", "phase": "Expansion"},
        {"peak": "2001-03-01", "trough": "2001-11-01", "start": "2001-03-01 00:00:00", "end": "2001-11-01 00:00:00", "phase": "Contraction"},
        {"peak": "2001-11-01", "trough": "2007-12-01", "start": "2001-11-01 00:00:00", "end": "2007-12-01 00:00:00", "phase": "Expansion"},
        {"peak": "2007-12-01", "trough": "2009-06-01", "start": "2007-12-01 00:00:00", "end": "2009-06-01 00:00:00", "phase": "Contraction"},
        {"peak": "2020-02-01", "trough": "2020-04-01", "start": "2009-06-01 00:00:00", "end": "2020-02-01 00:00:00", "phase": "Expansion"},
        {"peak": "2020-02-01", "trough": "2020-04-01", "start": "2020-02-01 00:00:00", "end": "2020-04-01 00:00:00", "phase": "Contraction"},
        {"peak": "2021-12-01", "trough": "2022-03-31", "start": "2020-04-01 00:00:00", "end": "2022-03-11 00:00:00", "phase": "Expansion"}
    ]
    with sqlite3.connect(db_name) as conn:
        cursor = conn.cursor()
        for cycle in business_cycles:
            cursor.execute('''
                INSERT INTO business_cycles (Peak_Month, Trough_Month, Start_Date, End_Date, Phase)
                VALUES (?, ?, ?, ?, ?)
            ''', (cycle["peak"], cycle["trough"], cycle["start"], cycle["end"], cycle["phase"]))
        print("Business cycle data inserted successfully.")

insert_business_cycle_data("financial_data.db")

Business cycle data inserted successfully.
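
Because `id` is generated automatically, re-running the insert cell appends the same rows again. One way to make the insert safe to repeat, sketched here under the assumption that `(Start_Date, End_Date, Phase)` identifies a cycle, is a unique index combined with `INSERT OR IGNORE`. The function name `insert_cycles_once` and the index name are hypothetical; the two sample rows are taken from the list above.

```python
import sqlite3

rows = [
    ("2001-03-01", "2001-11-01", "2001-03-01 00:00:00", "2001-11-01 00:00:00", "Contraction"),
    ("2007-12-01", "2009-06-01", "2007-12-01 00:00:00", "2009-06-01 00:00:00", "Contraction"),
]

def insert_cycles_once(conn, cycles):
    """Insert rows, skipping any whose (Start_Date, End_Date, Phase) already exists."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS business_cycles (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            Peak_Month TEXT, Trough_Month TEXT,
            Start_Date TEXT, End_Date TEXT, Phase TEXT
        )
    """)
    # The unique index is what lets INSERT OR IGNORE detect repeats.
    conn.execute("""
        CREATE UNIQUE INDEX IF NOT EXISTS idx_cycle_unique
        ON business_cycles (Start_Date, End_Date, Phase)
    """)
    conn.executemany("""
        INSERT OR IGNORE INTO business_cycles
            (Peak_Month, Trough_Month, Start_Date, End_Date, Phase)
        VALUES (?, ?, ?, ?, ?)
    """, cycles)
    conn.commit()

conn = sqlite3.connect(":memory:")
insert_cycles_once(conn, rows)
insert_cycles_once(conn, rows)  # the second run adds nothing
row_count = conn.execute("SELECT COUNT(*) FROM business_cycles").fetchone()[0]
print(row_count)
conn.close()
```

On a database that already contains duplicate rows, creating the unique index would fail, so the duplicates would need to be removed first.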


In [None]:
import pandas as pd
import pandas_datareader.data as web
import sqlite3
from datetime import datetime

class DataLoader:
    def __init__(self, db_name="financial_data.db"):
        self.db_name = db_name
        self.economic_indicators_tickers = ['UNRATE', 'PAYEMS', 'ICSA', 'CIVPART', 'INDPRO']
        self.yield_curve_tickers = ['DGS1MO', 'DGS3MO', 'DGS6MO', 'DGS1', 'DGS2', 'DGS3', 'DGS5', 'DGS7', 'DGS10', 'DGS20', 'DGS30']
        self.production_data_tickers = ['SAUNGDPMOMBD', 'ARENGDPMOMBD', 'IRNNGDPMOMBD', 'SAUNXGO','DPCCRV1Q225SBEA',
                                        'QATNGDPMOMBD', 'KAZNGDPMOMBD', 'IRQNXGO', 'IRNNXGO', 'KWTNGDPMOMBD', 'IPN213111S', 'PCU213111213111']

    def clean_data(self, data):
        # Function to clean data, remove leading/trailing single quotes, and convert to numeric
        cleaned_data = data.applymap(lambda x: x.strip("'") if isinstance(x, str) else x)
        cleaned_data = cleaned_data.apply(pd.to_numeric, errors='coerce')
        return cleaned_data

    def fetch_and_insert_data(self, tickers, table_name):
        start_date = '2000-12-31'
        end_date = datetime.now().strftime('%Y-%m-%d')
        try:
            data = web.DataReader(tickers, 'fred', start_date, end_date)
            data = data.interpolate(method='quadratic').bfill().ffill()
            data = data.resample('D').ffill().bfill()

            # Clean the data to handle formatting issues
            data = self.clean_data(data)

            # Convert the index (which is the date) to the correct format
            data.index = pd.to_datetime(data.index).strftime('%Y-%m-%d %H:%M:%S')

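            with sqlite3.connect(self.db_name) as conn:
                # Note: if_exists='replace' drops the table made by create_tables() and
                # rebuilds it from the DataFrame, so the PRIMARY KEY and index are not kept.
                data.to_sql(table_name, conn, if_exists='replace', index_label='Date')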
            print(f"Data inserted into {table_name} table")
        except Exception as e:
            print(f"Failed to fetch and insert the data: {e}")

def main():
    db_name = "financial_data.db"
    loader = DataLoader(db_name)
    loader.fetch_and_insert_data(loader.economic_indicators_tickers, 'economic_indicators')
    loader.fetch_and_insert_data(loader.yield_curve_tickers, 'yield_curve_prices')
    loader.fetch_and_insert_data(loader.production_data_tickers, 'production_data')

if __name__ == "__main__":
    main()


Data inserted into economic_indicators table
Data inserted into yield_curve_prices table
Data inserted into production_data table


In [None]:
import sqlite3
import pandas as pd

# Database name
db_name = "financial_data.db"

# Connect to the SQLite database
conn = sqlite3.connect(db_name)

# Fetch data from economic_indicators table
economic_indicators_query = "SELECT * FROM economic_indicators"
economic_indicators_df = pd.read_sql(economic_indicators_query, conn)
print("Economic Indicators Data:\n", economic_indicators_df.head())

# Fetch data from yield_curve_prices table
yield_curve_prices_query = "SELECT * FROM yield_curve_prices"
yield_curve_prices_df = pd.read_sql(yield_curve_prices_query, conn)
print("\nYield Curve Prices Data:\n", yield_curve_prices_df.head())
yield_curve_prices_df.to_csv('prices.csv')

# Fetch data from production_data table
production_data_query = "SELECT * FROM production_data"
production_data_df = pd.read_sql(production_data_query, conn)
print("\nProduction Data:\n", production_data_df.head())

# Fetch data from business_cycles table
business_cycles_query = "SELECT * FROM business_cycles"
business_cycles_df = pd.read_sql(business_cycles_query, conn)
print("\nBusiness Cycles Data:\n", business_cycles_df.head())

# Close the database connection
conn.close()
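
With the tables loaded, plain SQL already answers many questions, which gives a baseline for judging the natural language results later. A minimal sketch, assuming an in-memory copy of four of the business-cycle rows inserted earlier, that ranks phases by length using SQLite's `julianday()` function:

```python
import sqlite3

cycles = [
    ("2001-11-01 00:00:00", "2007-12-01 00:00:00", "Expansion"),
    ("2007-12-01 00:00:00", "2009-06-01 00:00:00", "Contraction"),
    ("2009-06-01 00:00:00", "2020-02-01 00:00:00", "Expansion"),
    ("2020-02-01 00:00:00", "2020-04-01 00:00:00", "Contraction"),
]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE business_cycles (Start_Date TEXT, End_Date TEXT, Phase TEXT)")
conn.executemany("INSERT INTO business_cycles VALUES (?, ?, ?)", cycles)

# julianday() converts the stored text timestamps to day counts, so the
# difference between two of them is the phase length in days.
durations = conn.execute("""
    SELECT Phase, Start_Date,
           CAST(julianday(End_Date) - julianday(Start_Date) AS INTEGER) AS days
    FROM business_cycles
    ORDER BY days DESC
""").fetchall()
for phase, start, days in durations:
    print(phase, start, days)
conn.close()
```

This works because the dates are stored as ISO-style text (`YYYY-MM-DD HH:MM:SS`), which SQLite's date functions can parse directly.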

# **Querying the Database with Natural Language**

Now that our environment is set up, and our database is populated with data, we can start querying it using natural language thanks to LangChain. This section demonstrates how to construct natural language queries to extract insights from our financial data. We'll explore various examples, showing how LangChain translates these queries into SQL commands and retrieves the relevant data from our SQLite database.


In [None]:
import sqlite3
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
import openai

# Set your OpenAI API key
API_KEY=OPENAI_API_KEY
openai.api_key = API_KEY

# initialize our SQL Database
db_uri = "sqlite:///financial_data.db"
db = SQLDatabase.from_uri(db_uri)

llm = OpenAI(openai_api_key=API_KEY, temperature=0, verbose=True)

# Create our SQL Chain instance for building our SQL Queries
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

# please replace this query with your own question
natural_language_question = "List the yield curve values on 8/29/2006 00:00:00."

# Generate the query
response = db_chain.run(natural_language_question)
print("Response", response)



> Entering new SQLDatabaseChain chain...
List the yield curve values on 8/29/2006 00:00:00.
SQLQuery: SELECT "DGS1MO", "DGS3MO", "DGS6MO", "DGS1", "DGS2", "DGS3", "DGS5", "DGS7", "DGS10", "DGS20", "DGS30" FROM yield_curve_prices WHERE "Date" = '2006-08-29 00:00:00' LIMIT 5;
SQLResult: [(5.19, 5.07, 5.16, 5.06, 4.87, 4.79, 4.77, 4.77, 4.79, 5.0, 4.93)]
Answer: 5.19, 5.07, 5.16, 5.06, 4.87, 4.79, 4.77, 4.77, 4.79, 5.0, 4.93
> Finished chain.
Response 5.19, 5.07, 5.16, 5.06, 4.87, 4.79, 4.77, 4.77, 4.79, 5.0, 4.93
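
Because the chain prints the SQL it generated, the answer can be cross-checked without the LLM. A minimal sketch, assuming an in-memory table that holds only the single row shown in the output above, reduced to three of its columns to keep it short; against the real database you would connect to `financial_data.db` instead.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE yield_curve_prices ("Date" TEXT, "DGS1MO" REAL, "DGS10" REAL, "DGS30" REAL)'
)
conn.execute(
    "INSERT INTO yield_curve_prices VALUES (?, ?, ?, ?)",
    ("2006-08-29 00:00:00", 5.19, 4.79, 4.93),
)

# Bind the date as a parameter rather than formatting it into the SQL string.
result = conn.execute(
    'SELECT "DGS1MO", "DGS10", "DGS30" FROM yield_curve_prices WHERE "Date" = ?',
    ("2006-08-29 00:00:00",),
).fetchall()
print(result)
conn.close()
```

Since `Date` is a `TEXT` column, the comparison is an exact string match. A value such as `8/29/2006` would return no rows, so the model has to rewrite the date into the stored `YYYY-MM-DD 00:00:00` form, as it did in the run above.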


# **Leveraging LangChain for Advanced SQL Queries**

In this section, we explore how to utilize LangChain to create an SQL agent capable of interpreting and executing SQL queries based on natural language input. This powerful feature allows users to interact with the SQLite database using conversational language, making data analysis more accessible and intuitive.

## **Creating the SQL Agent**

The `create_sql_agent` function initializes an SQL agent with specified properties, enabling the execution of SQL queries derived from natural language descriptions. The agent leverages a language model (provided by OpenAI) and a toolkit (`SQLDatabaseToolkit`) containing tools for query creation, execution, syntax checking, and more. The agent is configured to be verbose for detailed logging and is set to the `ZERO_SHOT_REACT_DESCRIPTION` type: a ReAct-style agent that decides which tool to call based only on each tool's description, without worked examples.

## **Custom Date Formatting Tool**

To enhance the agent's capabilities, we introduce a custom tool for date formatting (`DateFormatTool`). This tool ensures that dates mentioned in queries are correctly formatted to match the database schema ('YYYY-MM-DD 00:00:00'), allowing for precise data retrieval based on date-specific queries.

## **Example Query Execution**

We demonstrate the SQL agent's functionality by executing a sample query: retrieving yield curve values for a specific date. This showcases the agent's ability to translate natural language questions into SQL commands and fetch the relevant data from the database.



In [None]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

# Create an SQL Agent
"""
create_sql_agent Function:
- Function Purpose: Initializes an SQL agent. The method creates an agent with specified properties.
- Parameters Explained:
  - llm: The language model being used, in this case, provided by OpenAI.
  - toolkit: The SQLDatabaseToolkit, which contains tools for query creation, execution, syntax checking, and more.
  - verbose: A boolean flag to enable detailed logging.
  - agent_type: Specifies the type of agent. Here 'ZERO_SHOT_REACT_DESCRIPTION' is a ReAct-style agent that picks a tool based only on each tool's description, without worked examples.
"""

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose = True,
    agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    top_k=10

)

## Example query to query the SQL Agent
query_result = agent_executor.run(
    "List the yield curve values on 8/29/2006 00:00:00."
)

print("query result:", query_result)

# No connection needs closing here: SQLDatabase manages its own SQLAlchemy engine,
# and the sqlite3 connection opened in the earlier cell was already closed.



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should query the yield_curve_prices table

In [None]:
from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType

agent_executor.run("Describe the data in the production_data table.")



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the schema of the production_data table.
Action: sql_db_schema
Action Input: production_data
Observation:
CREATE TABLE production_data (
	"Date" TEXT, 
	"SAUNGDPMOMBD" REAL, 
	"ARENGDPMOMBD" REAL, 
	"IRNNGDPMOMBD" REAL, 
	"SAUNXGO" REAL, 
	"DPCCRV1Q225SBEA" REAL, 
	"QATNGDPMOMBD" REAL, 
	"KAZNGDPMOMBD" REAL, 
	"IRQNXGO" REAL, 
	"IRNNXGO" REAL, 
	"KWTNGDPMOMBD" REAL, 
	"IPN213111S" REAL, 
	"PCU213111213111" REAL
)

/*
3 rows from production_data table:
Date	SAUNGDPMOMBD	ARENGDPMOMBD	IRNNGDPMOMBD	SAUNXGO	DPCCRV1Q225SBEA	QATNGDPMOMBD	KAZNGDPMOMBD	IRQNXGO	IRNNXGO	KWTNGDPMOMBD	IPN213111S	PCU213111213111
2001-01-01	7890000.0	2120000.0	3571700.0	7118438.3562	2.6	680000.0	806469.863	1453907.5448	2456136.2948	1745865.7534	116.469

"[('2001-01-01', 7890000.0, 2120000.0, 3571700.0, 7118438.3562, 2.6, 680000.0, 806469.863, 1453907.5448, 2456136.2948, 1745865.7534, 116.4692, 159.6), ('2001-01-02', 7890000.0, 2120000.0, 3571700.0, 7118438.3562, 2.6, 680000.0, 806469.863, 1453907.5448, 2456136.2948, 1745865.7534, 116.4692, 159.6), ('2001-01-03', 7890000.0, 2120000.0, 3571700.0, 7118438.3562, 2.6, 680000.0, 806469.863, 1453907.5448, 2456136.294"

## **Build a custom tool**

In [None]:
!pip install python-dateutil

In [None]:
import re
from datetime import datetime
from dateutil import parser
from dateutil.parser import parse
from langchain.utilities import SQLDatabase
from langchain.agents import AgentType, Tool, initialize_agent, load_tools
from langchain.tools import AIPluginTool

# Custom function to format date
def format_date_for_db(query_date: str):
    try:
        # Use dateutil.parser to parse the date from various formats
        date_obj = parser.parse(query_date)
        # Format the date to match the database format (e.g., 'YYYY-MM-DD 00:00:00')
        formatted_date = date_obj.strftime("%Y-%m-%d 00:00:00")
        return formatted_date
    except ValueError as e:
        # Handle the error if the date format is incorrect
        return f"Error: {e}"

# Create a custom tool for date formatting
date_format_tool = Tool(
    name="DateFormatTool",
    func=format_date_for_db,
    description="Formats the date to match the database format 'YYYY-MM-DD 00:00:00'."
)

# Example usage
formatted_date = date_format_tool.run("December 10th 2006")
print(formatted_date)  # Output: "2006-12-10 00:00:00"


2006-12-10 00:00:00
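
`dateutil` guesses the input format, which is convenient but can misread ambiguous input such as `01/02/2006`. A stricter variant, sketched here with only the standard library, tries an explicit list of accepted formats and rejects everything else. The function name `normalize_date` and the format list are assumptions for illustration, not part of the notebook.

```python
from datetime import datetime

# Accepted input formats, tried in order (an illustrative list; extend as needed).
ACCEPTED_FORMATS = (
    "%Y-%m-%d",
    "%Y-%m-%d %H:%M:%S",
    "%m/%d/%Y",
    "%m/%d/%Y %H:%M:%S",
    "%B %d %Y",
)

def normalize_date(text: str) -> str:
    """Return text as 'YYYY-MM-DD 00:00:00', or raise ValueError if no format matches."""
    cleaned = text.strip()
    for fmt in ACCEPTED_FORMATS:
        try:
            parsed = datetime.strptime(cleaned, fmt)
        except ValueError:
            continue  # try the next format
        return parsed.strftime("%Y-%m-%d 00:00:00")
    raise ValueError(f"Unrecognized date format: {text!r}")

print(normalize_date("8/29/2006 00:00:00"))
print(normalize_date("December 20 2021"))
```

The trade-off is coverage: ordinal suffixes such as "December 20th" are not handled by this sketch, whereas `dateutil` parses them. Which behavior is preferable depends on how much freedom the agent should have when interpreting dates.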


In [None]:
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.agents.agent_toolkits import create_retriever_tool
import openai

# Set your OpenAI API key
API_KEY = OPENAI_API_KEY
openai.api_key = API_KEY

# Initialize the SQLite database
db_uri = "sqlite:///financial_data.db"
db = SQLDatabase.from_uri(db_uri)

# Initialize the LangChain LLM with OpenAI
llm = OpenAI(openai_api_key=API_KEY, temperature=0, verbose=True)
tools = load_tools(["requests_all"])
# Define few-shot examples specific to your database
few_shots = {
    "What was the highest unemployment rate last year?": "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "Average industrial production for the previous month?": "SELECT AVG(IPN213111S) FROM production_data WHERE Date >= date('now', 'start of month', '-1 month', '00:00:00') AND Date < date('now', 'start of month', '00:00:00');",
    "Show the five lowest 10-year yield rates of the current year.": "SELECT DGS10 FROM yield_curve_prices WHERE Date >= '2023-01-01 00:00:00' ORDER BY DGS10 ASC LIMIT 5;",
    "List the economic indicators for the first quarter of 2023.": "SELECT * FROM economic_indicators WHERE Date >= '2023-01-01 00:00:00' AND Date <= '2023-03-31 00:00:00';",
    "What are the latest available production numbers for Saudi Arabia?": "SELECT SAUNGDPMOMBD FROM production_data WHERE Date = (SELECT MAX(Date) FROM production_data);",
    "Compare the unemployment rate at the beginning and end of the last recession.": "SELECT UNRATE FROM economic_indicators WHERE Date IN (SELECT Start_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY Start_Date DESC LIMIT 1) OR Date IN (SELECT End_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY End_Date DESC LIMIT 1);",
    "Find the average civilian labor force participation rate for the last year.": "SELECT AVG(CIVPART) FROM economic_indicators WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "Show the change in 2-year yield rates over the past six months.": "SELECT DGS2 FROM yield_curve_prices WHERE Date >= date('now', '-6 months', '00:00:00') ORDER BY Date;",
    "What was the maximum production of natural gas in Qatar last year?": "SELECT MAX(QATNGDPMOMBD) FROM production_data WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "List the top 3 longest economic expansions since 2000.": "SELECT Start_Date, End_Date FROM business_cycles WHERE Phase = 'Expansion' AND Start_Date >= '2000-01-01 00:00:00' ORDER BY (julianday(End_Date) - julianday(Start_Date)) DESC LIMIT 3;"
}


# Create a retriever for few-shot examples
embeddings = OpenAIEmbeddings(openai_api_key=API_KEY)
few_shot_docs = [Document(page_content=question, metadata={"sql_query": few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()

# Create a custom tool for retrieving similar examples
retriever_tool = create_retriever_tool(retriever, name="sql_get_similar_examples", description="Retrieves similar SQL examples, use the YYYY-MM-DD 00:00:00 datetime format.")

# Create the SQL Agent with the custom tool
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    extra_tools=[retriever_tool, date_format_tool],
    top_k=30
)

# Example query using the SQL Agent
query_result = agent_executor.run(
     "What was the 30-Year treasury yield on December 20th 2021?"
)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table to get the 30-Year treasury yield.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should u
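
The retriever tool's job is to hand the agent the stored question most similar to the one being asked, together with its SQL. The notebook does this with OpenAI embeddings and FAISS. The sketch below swaps in plain word-overlap (Jaccard) similarity so that it runs without an API key. It illustrates the lookup idea only and is not how FAISS ranks results; the function names `tokens` and `most_similar_example` are hypothetical.

```python
import re

# Three of the few-shot pairs defined in the notebook cell above.
few_shots = {
    "What was the highest unemployment rate last year?":
        "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "What are the latest available production numbers for Saudi Arabia?":
        "SELECT SAUNGDPMOMBD FROM production_data WHERE Date = (SELECT MAX(Date) FROM production_data);",
    "Show the five lowest 10-year yield rates of the current year.":
        "SELECT DGS10 FROM yield_curve_prices WHERE Date >= '2023-01-01 00:00:00' ORDER BY DGS10 ASC LIMIT 5;",
}

def tokens(text: str) -> set:
    """Lowercase word set, used as a crude stand-in for an embedding."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def most_similar_example(question: str, examples: dict) -> tuple:
    """Return the (stored question, SQL) pair with the highest Jaccard overlap."""
    q = tokens(question)

    def score(stored: str) -> float:
        s = tokens(stored)
        return len(q & s) / len(q | s)

    best = max(examples, key=score)
    return best, examples[best]

match, sql = most_similar_example(
    "What was the lowest unemployment rate last year?", few_shots
)
print(match)
print(sql)
```

Word overlap misses paraphrases ("jobless rate" versus "unemployment rate"), which is the gap embedding-based retrieval is meant to close.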

In [None]:
%%writefile Return_Empty_Result.py
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator

INTERMEDIATE_STEPS_KEY = "intermediate_steps"
SQL_QUERY = "SQLQuery:"

class SQLDatabaseChain(Chain):
    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    database: SQLDatabase = Field(exclude=True)
    prompt: Optional[BasePromptTemplate] = None
    top_k: int = 5
    input_key: str = "query"
    output_key: str = "result"
    return_sql: bool = False
    return_intermediate_steps: bool = False
    return_direct: bool = False
    use_query_checker: bool = False
    query_checker_prompt: Optional[BasePromptTemplate] = None

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(sql_cmd)  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                if SQL_QUERY in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(checked_sql_command)  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise exc

    @property
    def _chain_type(self) -> str:
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)


Writing Return_Empty_Result.py
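
The chain written above strips an optional label from the model's output before executing it (the `if SQL_QUERY in sql_cmd` branch). The following is a minimal sketch of that one step, assuming `SQL_QUERY` is the string `"SQLQuery:"`; the constant is not shown in this excerpt, and `extract_sql` is a hypothetical helper name used only for illustration.

```python
# Assumption: the chain's SQL_QUERY constant is the label "SQLQuery:".
SQL_QUERY = "SQLQuery:"


def extract_sql(llm_output: str) -> str:
    """Return the SQL text, dropping a leading 'SQLQuery:' label if present."""
    sql_cmd = llm_output.strip()
    if SQL_QUERY in sql_cmd:
        # Keep the text after the first label (up to any second label),
        # mirroring the split(...)[1] used in the chain.
        sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
    return sql_cmd


labelled = extract_sql("SQLQuery: SELECT MAX(UNRATE) FROM economic_indicators;")
plain = extract_sql("  SELECT 1;\n")
print(labelled)
print(plain)
```

Output without the label passes through unchanged apart from whitespace trimming, so the same helper covers both cases.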


## **Creation of the Vector Database with Few-Shot Examples**

In [None]:
import sqlite3
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
#from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.agents.agent_toolkits import create_retriever_tool
import openai

# Set your OpenAI API key
API_KEY=OPENAI_API_KEY
openai.api_key = API_KEY

# Initialize the SQLite database
db_uri = "sqlite:///financial_data.db"
db = SQLDatabase.from_uri(db_uri)

# Initialize the LangChain LLM with OpenAI
llm = OpenAI(openai_api_key=API_KEY, temperature=0, verbose=True)

# Define few-shot examples specific to your database
few_shots = {
    "What was the highest unemployment rate last year?": "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "Average industrial production for the previous month?": "SELECT AVG(IPN213111S) FROM production_data WHERE Date >= date('now', 'start of month', '-1 month') AND Date < date('now', 'start of month');",
    "Show the five lowest 10-year yield rates of the current year.": "SELECT * FROM yield_curve_prices WHERE Date >= '2023-01-01 00:00:00' ORDER BY DGS10 ASC LIMIT 5;",
    "List the economic indicators for the first quarter of 2023.": "SELECT * FROM economic_indicators WHERE Date >= '2023-01-01 00:00:00' AND Date <= '2023-03-31 00:00:00';",
    "What are the latest available production numbers for Saudi Arabia?": "SELECT SAUNGDPMOMBD FROM production_data ORDER BY Date DESC LIMIT 1;",
    "Compare the unemployment rate at the beginning and end of the last recession.": "SELECT UNRATE FROM economic_indicators WHERE Date IN (SELECT Start_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY Start_Date DESC LIMIT 1) OR Date IN (SELECT End_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY End_Date DESC LIMIT 1);",
    "Find the average civilian labor force participation rate for the last year.": "SELECT AVG(CIVPART) FROM economic_indicators WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "Show the change in 2-year yield rates over the past six months.": "SELECT DGS2 FROM yield_curve_prices WHERE Date >= date('now', '-6 months') ORDER BY Date;",
    "What was the maximum production of natural gas in Qatar last year?": "SELECT MAX(QATNGDPMOMBD) FROM production_data WHERE Date >= '2022-01-01 00:00:00' AND Date <= '2022-12-31 00:00:00';",
    "List the top 3 longest economic expansions since 2000.": "SELECT Start_Date, End_Date FROM business_cycles WHERE Phase = 'Expansion' AND Start_Date >= '2000-01-01 00:00:00' ORDER BY (julianday(End_Date) - julianday(Start_Date)) DESC LIMIT 3;"
}

# Create a retriever for few-shot examples
embeddings = OpenAIEmbeddings(openai_api_key=API_KEY)
few_shot_docs = [Document(page_content=question, metadata={"sql_query": few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()

# Create a custom tool for retrieving similar examples
retriever_tool = create_retriever_tool(retriever, name="sql_get_relevant_examples", description="Retrieves similar SQL examples, please use the 00:00:00 datetime format.")

# Create the SQL Agent with the custom tool
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    extra_tools=[retriever_tool, date_format_tool],
    top_k=30
)

# Example query using the SQL Agent
query_result = agent_executor.run(
    "What was the 20-Year treasury yield on January 10th 2006?"
)
print("Query Result:", query_result)

# Close the database connection
conn = sqlite3.connect(db_name)
conn.close()



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table to get the 20-Year treasury yield.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should q

In [None]:
import sqlite3
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.memory import ConversationEntityMemory
from langchain.memory.entity import InMemoryEntityStore
import openai

# Set your OpenAI API key
API_KEY = OPENAI_API_KEY
openai.api_key = API_KEY

# Initialize the SQLite database
db_uri = "sqlite:///financial_data.db"
db = SQLDatabase.from_uri(db_uri)

# Initialize the LangChain LLM with OpenAI
llm = OpenAI(openai_api_key=API_KEY, temperature=0.1, verbose=True)

# Initialize InMemoryEntityStore for memory
entity_store = InMemoryEntityStore()

# Initialize ConversationEntityMemory with the LLM and entity store
memory = ConversationEntityMemory(llm=llm, entity_store=entity_store)

# Create the SQL Agent
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    top_k=5
)

# Example query using the SQL Agent with Memory
query_result = agent_executor.run(
    input="What was the 3 month yield curve rate on 2007-01-25 00:00:00?",
    memory=memory  # Passing the memory to the agent
)

print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table to get the 3 month yield curve rate.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should

In [None]:
# Subsequent queries can also utilize the memory
another_query_result = agent_executor.run(
    input="What does the average shape of the yield curve in 2023 imply about consumer spending?",
    memory=memory
)

print("Another Query Result:", another_query_result)

# Close the database connection
conn = sqlite3.connect("financial_data.db")
conn.close()



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the schema of the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TIMESTAMP, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should query the averag

In [None]:
# Subsequent queries can also utilize the memory
another_query_result = agent_executor.run(
    input="How does the current shape of the yield curve impact industrial production?",
    memory=memory
)

print("Another Query Result:", another_query_result)

# Close the database connection
conn = sqlite3.connect("financial_data.db")
conn.close()



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the schema of the production_data and yield_curve_prices tables.
Action: sql_db_schema
Action Input: production_data, yield_curve_prices
Observation:
CREATE TABLE production_data (
	"Date" TIMESTAMP, 
	"SAUNGDPMOMBD" REAL, 
	"ARENGDPMOMBD" REAL, 
	"IRNNGDPMOMBD" REAL, 
	"SAUNXGO" REAL, 
	"DPCCRV1Q225SBEA" REAL, 
	"QATNGDPMOMBD" REAL, 
	"KAZNGDPMOMBD" REAL, 
	"IRQNXGO" REAL, 
	"IRNNXGO" REAL, 
	"KWTNGDPMOMBD" REAL, 
	"IPN213111S" REAL, 
	"PCU213111213111" REAL
)

/*
3 rows from production_data table:
Date	SAUNGDPMOMBD	ARENGDPMOMBD	IRNNGDPMOMBD	SAUNXGO	DPCCRV1Q225SBEA	QATNGDPMOMBD	KAZNGDPMOMBD	IRQNXGO	IRNNXGO	KWTNGDPMOMBD	IPN213111S	PCU213111213111
2001-01-01 00:00:00	7890000.0	2120000.0	3571700.0	7118438.3562	2.6	680000.0

In [None]:
# Example query using the SQL Agent with memory
query = "Describe the average shape of the yield curve in 2023, how does it compare to 2022?"
query_result = agent_executor.run(input=query, memory=memory)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the schema of the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should query the average sha

In [None]:
# Example query using the SQL Agent with memory
query = "Describe the average shape of the yield curve in 2023, are short term rates higher than long term rates?"
query_result = agent_executor.run(input=query, memory=memory)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table to get the average shape of the yield curve in 2023.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought:

In [None]:
# Example query using the SQL Agent with memory
query = "Given your previous response, on average was the yield curve inverted in 2023"
query_result = agent_executor.run(input=query, memory=memory)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, production_data, yield_curve_prices
Thought: I should query the schema of the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should query the yield_curve

In [None]:
import math
import os
import sqlite3
import uuid
from datetime import datetime, timedelta

import openai
import pandas as pd
import redis
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from langchain.agents import AgentExecutor, Tool, ZeroShotAgent, create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit, create_retriever_tool
from langchain.agents.agent_types import AgentType
from langchain.chains import LLMChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory, ConversationEntityMemory
from langchain.memory.chat_message_histories import (
    RedisChatMessageHistory,
    SQLChatMessageHistory,
)
from langchain.memory.entity import InMemoryEntityStore
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.utilities import GoogleSearchAPIWrapper, SerpAPIWrapper, SQLDatabase
from langchain.vectorstores import FAISS
from langchain_experimental.sql.base import SQLDatabaseSequentialChain

# Local module written earlier; this SQLDatabaseChain is used instead of the
# langchain_experimental class of the same name
from Return_Empty_Result import SQLDatabaseChain

# Load environment variables from .env file
load_dotenv()

serpapi_key = os.getenv("SERPAPI_API_KEY")

def generate_session_id(interval_minutes=5):
    try:
        now = datetime.utcnow()  # Get the current UTC time
        minutes_since_midnight = now.hour * 60 + now.minute  # Calculate minutes since midnight
        interval_count = math.floor(minutes_since_midnight / interval_minutes)  # Calculate the interval count
        session_id = now.strftime("%Y%m%d") + f"_{interval_count}"
        return session_id
    except Exception as e:
        print(f"Error in generating session ID: {e}")  # Handle exceptions and print an error message
        return None  # Return None in case of an error

def run_agent_query(agent_executor, query):
    try:
        return agent_executor.run(input=query)
    except openai.BadRequestError as e:
        print(f"Token limit exceeded: {e}")
        return "Sorry, the query is too long to process."
    except Exception as e:
        print(f"An unexpected error occurred: {type(e).__name__}, {str(e)}")
        return f"An error occurred: {str(e)}"

session_id = generate_session_id()
# Initialize InMemoryEntityStore for memory
entity_store = InMemoryEntityStore()
# Initialize ConversationEntityMemory with the LLM and entity store
memory = ConversationEntityMemory(llm=llm, entity_store=entity_store)

# Initialize the SQLite database
db_uri = "sqlite:///financial_data.db"
db = SQLDatabase.from_uri(db_uri)
sql_tool = Tool(name="SQL", func=retriever_tool.run, description="Retrieves similar SQL examples, please use the 00:00:00 datetime format.")

# Collect the tools for the agent (only the SQL examples tool is registered here)
tools = [sql_tool]

# Use ZeroShotAgent to create the prompt
prefix = """
As an economic data analyst and a SQLite expert, I specialize in interpreting economic indicators and performing data analysis using specific SQL queries.
My expertise includes the following tables and their respective fields, you also have access to these tools:

1. economic_indicators (DATE, UNRATE, PAYEMS, ICSA, CIVPART, INDPRO)
2. yield_curve_prices (Date, DGS1MO, DGS3MO, DGS6MO, DGS1, DGS2, DGS3, DGS5, DGS7, DGS10, DGS20, DGS30)
3. production_data (Date, SAUNGDPMOMBD, ARENGDPMOMBD, IRNNGDPMOMBD, SAUNXGO, QATNGDPMOMBD, KAZNGDPMOMBD, IRQNXGO, IRNNXGO, KWTNGDPMOMBD, IPN213111S, PCU213111213111, DPCCRV1Q225SBEA)
4. business_cycles (Peak_Month, Trough_Month, Start_Date, End_Date, Phase)

Please ask me questions related to these tables and fields. I will use SQL queries and web searches to provide accurate information on economic trends, indicators, and analysis based on this data.
Note that my responses are constrained to the data available in these tables. If I do not know an answer, I will say I do not know.
"""

suffix = """
Begin!

{chat_history}
Question: {input}
{agent_scratchpad}
"""
# Initialize the LangChain LLM with OpenAI
llm = OpenAI(openai_api_key=API_KEY, temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=False, use_query_checker=True, return_intermediate_steps=False, top_k = 5)
prompt = ZeroShotAgent.create_prompt(tools=tools, prefix=prefix, suffix=suffix, input_variables=["input", "chat_history", "agent_scratchpad"])
llm_chain = LLMChain(llm=llm, prompt=prompt)

# Define few-shot examples and create retriever
few_shots = {
    "What was the highest unemployment rate last year?": "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "Average industrial production for the previous month?": "SELECT AVG(IPN213111S) FROM production_data WHERE Date >= date('now', 'start of month', '-1 month') AND Date < date('now', 'start of month');",
    "Show the five lowest 10-year yield rates of the current year.": "SELECT * FROM yield_curve_prices WHERE Date >= '2023-01-01' ORDER BY DGS10 ASC LIMIT 5;",
    "List the economic indicators for the first quarter of 2023.": "SELECT * FROM economic_indicators WHERE Date >= '2023-01-01' AND Date <= '2023-03-31';",
    "What are the latest available production numbers for Saudi Arabia?": "SELECT SAUNGDPMOMBD FROM production_data ORDER BY Date DESC LIMIT 1;",
    "Compare the unemployment rate at the beginning and end of the last recession.": "SELECT UNRATE FROM economic_indicators WHERE Date IN (SELECT Start_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY Start_Date DESC LIMIT 1) OR Date IN (SELECT End_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY End_Date DESC LIMIT 1);",
    "Find the average civilian labor force participation rate for the last year.": "SELECT AVG(CIVPART) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "Show the change in 2-year yield rates over the past six months.": "SELECT DGS2 FROM yield_curve_prices WHERE Date >= date('now', '-6 months') ORDER BY Date;",
    "What was the maximum production of natural gas in Qatar last year?": "SELECT MAX(QATNGDPMOMBD) FROM production_data WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "List the top 3 longest economic expansions since 2000.": "SELECT Start_Date, End_Date FROM business_cycles WHERE Phase = 'Expansion' AND Start_Date >= '2000-01-01' ORDER BY (julianday(End_Date) - julianday(Start_Date)) DESC LIMIT 3;"
}
embeddings = OpenAIEmbeddings(openai_api_key=API_KEY)
few_shot_docs = [Document(page_content=question, metadata={"sql_query": few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()
tool_description = """
This tool will help you get similar SQL examples, please use the 00:00:00 datetime format to adapt them to the user question.
Input to this tool should be the user question.
"""
# Create a custom tool for retrieving similar examples
retriever_tool = create_retriever_tool(retriever, name="sql_get_similar_examples", description=tool_description)

# Initialize ConversationBufferMemory and a SQL-backed chat history
memory = ConversationBufferMemory(memory_key="chat_history")
random_session_id = str(uuid.uuid4())
engine = create_engine(db_uri)
Session = sessionmaker(bind=engine)
session = Session()
chat_message_history = SQLChatMessageHistory(session_id=random_session_id, connection_string=db_uri)

# Create the SQL Agent with memory, chat history, and retriever tool
toolkit = SQLDatabaseToolkit(db=db, llm=llm, chat_message_history=chat_message_history)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    extra_tools=[retriever_tool],
    top_k=5
)
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory, handle_parsing_errors=True)
# Run the query through the SQL agent (agent_chain above is built but not used here)
query = "Describe the yield curve on August 10th 2006. Was the curve inverted?"
query_result = agent_executor.run(input=query)

# Print the query result
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, economic_indicators, message_store, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table to get the yield curve data.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TEXT, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I

In [None]:
import sqlite3
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.memory import ConversationEntityMemory
from langchain.memory.entity import InMemoryEntityStore
from langchain.memory.chat_message_histories import SQLChatMessageHistory
import openai
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from datetime import datetime
from sqlalchemy import Column, DateTime, Integer, Text
from sqlalchemy.orm import declarative_base  # current import location (SQLAlchemy 1.4+)
from langchain.memory.chat_message_histories.sql import BaseMessageConverter
from langchain.schema import HumanMessage, AIMessage, BaseMessage, SystemMessage
import uuid

# Define the CustomMessage class and CustomMessageConverter
Base = declarative_base()

class CustomMessage(Base):
    __tablename__ = "custom_message_store"
    id = Column(Integer, primary_key=True)
    session_id = Column(Text)
    type = Column(Text)
    content = Column(Text)
    created_at = Column(DateTime)
    author_email = Column(Text)

class CustomMessageConverter(BaseMessageConverter):
    def __init__(self, author_email: str):
        self.author_email = author_email

    def from_sql_model(self, sql_message: CustomMessage) -> BaseMessage:
        if sql_message.type == "human":
            return HumanMessage(content=sql_message.content)
        elif sql_message.type == "ai":
            return AIMessage(content=sql_message.content)
        elif sql_message.type == "system":
            return SystemMessage(content=sql_message.content)
        else:
            raise ValueError(f"Unknown message type: {sql_message.type}")

    def to_sql_model(self, message: BaseMessage, session_id: str) -> CustomMessage:
        now = datetime.now()
        return CustomMessage(
            session_id=session_id,
            type=message.type,
            content=message.content,
            created_at=now,
            author_email=self.author_email,
        )

    def get_sql_model_class(self) -> CustomMessage:
        return CustomMessage

# Define your database connection
connection_string = "sqlite:///financial_data.db"
engine = create_engine(connection_string)
Session = sessionmaker(bind=engine)
session = Session()

# Initialize chat history with custom message converter
chat_message_history = SQLChatMessageHistory(
    session_id=str(uuid.uuid4()),
    connection_string=connection_string,
    custom_message_converter=CustomMessageConverter(author_email="your_email@example.com"),
)
# Create tables
Base.metadata.create_all(engine)

# Initialize the SQLite database
db = SQLDatabase.from_uri(connection_string)

# Initialize ConversationEntityMemory with the LLM
memory = ConversationEntityMemory(llm=llm, entity_store=InMemoryEntityStore())

# Create the SQL Agent
toolkit = SQLDatabaseToolkit(db=db, llm=llm, chat_message_history=chat_message_history)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    top_k=10
)

# Example query using the SQL Agent with memory
query = "What was the 3 month yield curve rate on 2022-01-25 00:00:00?"
query_result = agent_executor.run(input=query, memory=memory)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, custom_message_store, economic_indicators, message_store, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TIMESTAMP, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I

In [None]:
# Example query using the SQL Agent with memory
query = "Describe the average shape of the yield curve in 2023?"
query_result = agent_executor.run(input=query, memory=memory)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, custom_message_store, economic_indicators, message_store, production_data, yield_curve_prices
Thought: I should query the yield_curve_prices table to get the average shape of the yield curve.
Action: sql_db_schema
Action Input: yield_curve_prices
Observation:
CREATE TABLE yield_curve_prices (
	"Date" TIMESTAMP, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.

In [None]:
# Example query using the SQL Agent with memory
query = "How about 2024?"
query_result = agent_executor.run(input=query, memory=memory)
print("Query Result:", query_result)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: business_cycles, custom_message_store, economic_indicators, message_store, production_data, yield_curve_prices
Thought: I should query the schema of the production_data table.
Action: sql_db_schema
Action Input: production_data
Observation:
CREATE TABLE production_data (
	"Date" TIMESTAMP, 
	"SAUNGDPMOMBD" REAL, 
	"ARENGDPMOMBD" REAL, 
	"IRNNGDPMOMBD" REAL, 
	"SAUNXGO" REAL, 
	"DPCCRV1Q225SBEA" REAL, 
	"QATNGDPMOMBD" REAL, 
	"KAZNGDPMOMBD" REAL, 
	"IRQNXGO" REAL, 
	"IRNNXGO" REAL, 
	"KWTNGDPMOMBD" REAL, 
	"IPN213111S" REAL, 
	"PCU213111213111" REAL
)

/*
3 rows from production_data table:
Date	SAUNGDPMOMBD	ARENGDPMOMBD	IRNNGDPMOMBD	SAUNXGO	DPCCRV1Q225SBEA	QATNGDPMOMBD	KAZNGDPMOMBD	IRQNXGO	IRNNXGO	KWTNGDPMOMBD	IPN213111S	PCU213111213111
2001-01-01 00:00:00	7890000.0	2120000.0	3571700.0	7118438.3562	2.6	680000.0	806469

In [None]:
%%writefile setup_database.py
import sqlite3

def create_tables(db_name):
    with sqlite3.connect(db_name) as conn:
        cursor = conn.cursor()

        # Create economic indicators table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS economic_indicators (
                DATE TEXT PRIMARY KEY,
                UNRATE REAL,
                PAYEMS REAL,
                ICSA REAL,
                CIVPART REAL,
                INDPRO REAL
            )
        ''')

        # Create yield_curve_prices table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS yield_curve_prices (
                Date TEXT PRIMARY KEY,
                DGS1MO REAL,
                DGS3MO REAL,
                DGS6MO REAL,
                DGS1 REAL,
                DGS2 REAL,
                DGS3 REAL,
                DGS5 REAL,
                DGS7 REAL,
                DGS10 REAL,
                DGS20 REAL,
                DGS30 REAL
            )
        ''')

        # Create production_data table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS production_data (
                Date TEXT PRIMARY KEY,
                SAUNGDPMOMBD REAL,
                ARENGDPMOMBD REAL,
                IRNNGDPMOMBD REAL,
                SAUNXGO REAL,
                QATNGDPMOMBD REAL,
                KAZNGDPMOMBD REAL,
                IRQNXGO REAL,
                IRNNXGO REAL,
                KWTNGDPMOMBD REAL,
                IPN213111S REAL,
                PCU213111213111 REAL,
                DPCCRV1Q225SBEA REAL
            )
        ''')

        # Create business_cycles table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS business_cycles (
                Peak_Month TEXT,
                Trough_Month TEXT,
                Start_Date TEXT,
                End_Date TEXT,
                Phase TEXT
            )
        ''')

        print("Tables created successfully.")

def insert_business_cycle_data(db_name):
    business_cycles = [
        # [Insert business cycle data here as in your provided code]
        # Each entry is a dict with the keys "peak", "trough", "start", "end", and "phase".
    ]
    with sqlite3.connect(db_name) as conn:
        cursor = conn.cursor()
        for cycle in business_cycles:
            cursor.execute('''
                INSERT INTO business_cycles (Peak_Month, Trough_Month, Start_Date, End_Date, Phase)
                VALUES (?, ?, ?, ?, ?)
            ''', (cycle["peak"], cycle["trough"], cycle["start"], cycle["end"], cycle["phase"]))
        print("Business cycle data inserted successfully.")

if __name__ == "__main__":
    create_tables("financial_data.db")
    insert_business_cycle_data("financial_data.db")
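
The `business_cycles` list above is left empty, so the sketch below shows the row shape that `insert_business_cycle_data` expects and a duration query against the resulting table. It runs against an in-memory database, and the three rows are invented placeholders chosen only to illustrate the format. They are not real business cycle dates.

```python
import sqlite3

# Hypothetical placeholder rows: the dates are invented to show the expected shape only.
sample_cycles = [
    {"peak": "2010-06", "trough": "2010-01", "start": "2010-01-01", "end": "2010-06-30", "phase": "Expansion"},
    {"peak": "2010-06", "trough": "2010-09", "start": "2010-07-01", "end": "2010-09-30", "phase": "Contraction"},
    {"peak": "2012-12", "trough": "2010-09", "start": "2010-10-01", "end": "2012-12-31", "phase": "Expansion"},
]

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE IF NOT EXISTS business_cycles (
           Peak_Month TEXT, Trough_Month TEXT,
           Start_Date TEXT, End_Date TEXT, Phase TEXT)"""
)
# Named placeholders map each dict key onto the matching column.
conn.executemany(
    "INSERT INTO business_cycles VALUES (:peak, :trough, :start, :end, :phase)",
    sample_cycles,
)

# Longest phase by duration in days, using SQLite's julianday().
longest = conn.execute(
    """SELECT Phase, Start_Date, End_Date,
              julianday(End_Date) - julianday(Start_Date) AS days
       FROM business_cycles
       ORDER BY days DESC
       LIMIT 1"""
).fetchone()
print(longest)
```

Storing dates as ISO 8601 text keeps them sortable and lets SQLite's date functions operate on them directly.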

In [None]:
%%writefile fetch_insert_data.py
import pandas as pd
import pandas_datareader.data as web
import sqlite3
from datetime import datetime

class DataLoader:
    def __init__(self, db_name="financial_data.db"):
        self.db_name = db_name
        self.economic_indicators_tickers = ['UNRATE', 'PAYEMS', 'ICSA', 'CIVPART', 'INDPRO']
        self.yield_curve_tickers = ['DGS1MO', 'DGS3MO', 'DGS6MO', 'DGS1', 'DGS2', 'DGS3', 'DGS5', 'DGS7', 'DGS10', 'DGS20', 'DGS30']
        self.production_data_tickers = ['SAUNGDPMOMBD', 'ARENGDPMOMBD', 'IRNNGDPMOMBD', 'SAUNXGO', 'DPCCRV1Q225SBEA',
                                        'QATNGDPMOMBD', 'KAZNGDPMOMBD', 'IRQNXGO', 'IRNNXGO', 'KWTNGDPMOMBD', 'IPN213111S', 'PCU213111213111']

    def fetch_and_insert_data(self, tickers, table_name):
        start_date = '2000-12-31'
        end_date = datetime.now().strftime('%Y-%m-%d')
        try:
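            data = web.DataReader(tickers, 'fred', start_date, end_date)
            # Fill gaps between observations (quadratic interpolation requires SciPy),
            # then upsample to a daily frequency
            data = data.interpolate(method='quadratic').bfill().ffill()
            data = data.resample('D').ffill().bfill()

            with sqlite3.connect(self.db_name) as conn:
                # if_exists='replace' drops and recreates the table on every run
                data.to_sql(table_name, conn, if_exists='replace', index_label='Date')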
            print(f"Data inserted into {table_name} table")
        except Exception as e:
            print(f"Failed to fetch and insert the data: {e}")

def main():
    db_name = "financial_data.db"
    loader = DataLoader(db_name)
    loader.fetch_and_insert_data(loader.economic_indicators_tickers, 'economic_indicators')
    loader.fetch_and_insert_data(loader.yield_curve_tickers, 'yield_curve_prices')
    loader.fetch_and_insert_data(loader.production_data_tickers, 'production_data')

if __name__ == "__main__":
    main()
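
After loading, it can help to confirm that each table actually received rows and covers the expected date range. The following is a minimal sketch using only the standard library. The helper name `summarize_table` is hypothetical, and the demo runs on an in-memory database with made-up rows rather than on `financial_data.db`.

```python
import sqlite3


def summarize_table(conn, table_name):
    """Return (row_count, first_date, last_date) for a table with a Date column."""
    # Only interpolate table names that actually exist in the database.
    known = {
        row[0]
        for row in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    }
    if table_name not in known:
        raise ValueError(f"Unknown table: {table_name}")
    return conn.execute(
        f'SELECT COUNT(*), MIN("Date"), MAX("Date") FROM "{table_name}"'
    ).fetchone()


# Demo on an in-memory database with made-up rows.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE yield_curve_prices ("Date" TIMESTAMP, "DGS10" REAL)')
conn.executemany(
    "INSERT INTO yield_curve_prices VALUES (?, ?)",
    [("2001-01-01 00:00:00", 1.0), ("2001-01-02 00:00:00", 2.0)],
)
print(summarize_table(conn, "yield_curve_prices"))
```

To check the real database, pass `sqlite3.connect("financial_data.db")` and one of the three table names used by the loader.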

In [None]:
%%writefile initialize_agent.py
import os
import uuid

import openai
from dotenv import load_dotenv
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.chains import LLMChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain.schema import Document
from langchain.utilities import SerpAPIWrapper, SQLDatabase
from langchain.vectorstores import FAISS

# Load environment variables from .env file
load_dotenv()

# Read the SerpAPI key from the environment (set SERPAPI_API_KEY in .env); never hard-code secrets
serpapi_key = os.getenv("SERPAPI_API_KEY")

def initialize_agent(api_key, db_uri):
    # Set OpenAI API key, falling back to the environment when none is passed in
    api_key = api_key or os.getenv('OPENAI_API_KEY')
    openai.api_key = api_key

    # Initialize Redis Memory
    session_id = str(uuid.uuid4())
    message_history = RedisChatMessageHistory(url="redis://172.17.0.2:6379/0", ttl=600, session_id=session_id)
    memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=message_history)

    # Initialize the SerpAPI Wrapper with your API key
    search = SerpAPIWrapper(serpapi_api_key=serpapi_key)
    serpapi_tool = Tool(name="Search", func=search.run, description="useful for when you need to answer questions about the yield curve, finance or economics")

    # Initialize the SQLite database
    db = SQLDatabase.from_uri(db_uri)

    # Define few-shot examples specific to your database
    few_shots = {
        "What was the highest unemployment rate last year?": "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
        "Average industrial production for the previous month?": "SELECT AVG(IPN213111S) FROM production_data WHERE Date >= date('now', 'start of month', '-1 month') AND Date < date('now', 'start of month');",
        "Show the five lowest 10-year yield rates of the current year.": "SELECT * FROM yield_curve_prices WHERE Date >= '2023-01-01' ORDER BY DGS10 ASC LIMIT 5;",
        "List the economic indicators for the first quarter of 2023.": "SELECT * FROM economic_indicators WHERE Date >= '2023-01-01' AND Date <= '2023-03-31';",
        "What are the latest available production numbers for Saudi Arabia?": "SELECT SAUNGDPMOMBD FROM production_data ORDER BY Date DESC LIMIT 1;",
        "Compare the unemployment rate at the beginning and end of the last recession.": "SELECT UNRATE FROM economic_indicators WHERE Date IN (SELECT Start_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY Start_Date DESC LIMIT 1) OR Date IN (SELECT End_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY End_Date DESC LIMIT 1);",
        "Find the average civilian labor force participation rate for the last year.": "SELECT AVG(CIVPART) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
        "Show the change in 2-year yield rates over the past six months.": "SELECT DGS2 FROM yield_curve_prices WHERE Date >= date('now', '-6 months') ORDER BY Date;",
        "What was the maximum production of natural gas in Qatar last year?": "SELECT MAX(QATNGDPMOMBD) FROM production_data WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
        "List the top 3 longest economic expansions since 2000.": "SELECT Start_Date, End_Date FROM business_cycles WHERE Phase = 'Expansion' AND Start_Date >= '2000-01-01' ORDER BY (julianday(End_Date) - julianday(Start_Date)) DESC LIMIT 3;"
    }

    # Create a retriever for few-shot examples
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    few_shot_docs = [Document(page_content=question, metadata={"sql_query": few_shots[question]}) for question in few_shots.keys()]
    vector_db = FAISS.from_documents(few_shot_docs, embeddings)
    retriever = vector_db.as_retriever()

    # Create a custom tool for retrieving similar examples
    retriever_tool = create_retriever_tool(retriever, name="sql_get_similar_examples", description="Retrieves similar SQL examples.")

    # Define the SQL tool; it executes a SQL statement against the database
    sql_tool = Tool(name="SQL", func=db.run, description="useful for querying financial data from the database; input must be a valid SQLite query")

    # Add the SerpAPI, SQL, and few-shot retriever tools to your list of tools
    tools = [serpapi_tool, sql_tool, retriever_tool]

    # Use ZeroShotAgent to create the prompt
    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!\n\n{chat_history}\nQuestion: {input}\n{agent_scratchpad}"""
    prompt = ZeroShotAgent.create_prompt(tools=tools, prefix=prefix, suffix=suffix, input_variables=["input", "chat_history", "agent_scratchpad"])

    # Initialize the LangChain LLM with OpenAI
    llm = OpenAI(openai_api_key=api_key, temperature=0.2, verbose=True)
    llm_chain = LLMChain(llm=llm, prompt=prompt)

    agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)

    return agent_chain

if __name__ == "__main__":
    API_KEY = os.getenv('OPENAI_API_KEY')
    db_uri = "sqlite:///financial_data.db"
    agent_executor = initialize_agent(API_KEY, db_uri)

    # Example query using the agent
    query_result = agent_executor.run(input="Based on the yield curve, did the economy appear to be in an expansionary phase in 2017, and why?")
    print("Query Result:", query_result)
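
The script relies on `load_dotenv()` and `os.getenv` to supply its API keys, and `os.getenv` returns `None` silently when a variable is missing. One way to fail early with a clear message is sketched below. The helper `require_env` and the variable name `EXAMPLE_API_KEY` are hypothetical and used only for illustration.

```python
import os


def require_env(name):
    """Return the value of an environment variable, or fail with a clear message."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set; add it to your .env file."
        )
    return value


# Demo with a hypothetical variable name so no real key is needed.
os.environ["EXAMPLE_API_KEY"] = "placeholder-value"
print(require_env("EXAMPLE_API_KEY"))
```

In `initialize_agent.py`, the same check could wrap `OPENAI_API_KEY` and `SERPAPI_API_KEY` right after `load_dotenv()`.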
