# Medium LangChain Demo

This repository contains the Jupyter notebook `Medium_LangChain_Demo.ipynb`, part of a Medium article series on using LangChain with SQL for financial data analysis. The notebook loads economic and market data from FRED into a SQLite database. It then queries that database in natural language through LangChain.

## Overview

The `Medium_LangChain_Demo.ipynb` notebook is aimed at AI architects and ML engineers who want to explore what LangChain can do with SQL databases. It focuses on financial data analysis. It shows how natural language processing (NLP) can simplify querying: a large language model translates plain-English questions into SQL and runs them against the database.

## Features

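- Integration of LangChain with a SQLite database.
- Natural-language querying through `SQLDatabaseChain`.
- Financial data analysis over FRED economic indicators, Treasury yield-curve rates, production data, and U.S. business-cycle phases.
- SQL agents built with `create_sql_agent` and `SQLDatabaseToolkit`, including a few-shot example retriever backed by OpenAI embeddings and a FAISS vector store.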

## Getting Started

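To get started with this demo:

1. Clone this repository to your local machine.
2. Make sure Jupyter Notebook or JupyterLab is installed.
3. Install the required Python packages: `pandas`, `pandas-datareader`, `scikit-learn`, `plotly`, `langchain`, `langchain-experimental`, `openai`, `tiktoken`, and `faiss-cpu` (the notebook's first cell installs them). `sqlite3` ships with Python's standard library and needs no installation.
4. Replace the `"your-key"` placeholder in the notebook with your OpenAI API key.
5. Open `Medium_LangChain_Demo.ipynb` in Jupyter Notebook or JupyterLab and run the cells in order.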
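A quick way to confirm that the package installation worked is to check that Python's import system can find each package. The sketch below is a minimal example. The module list is an assumption based on the notebook's install cell. Several import names differ from the pip names: `scikit-learn` imports as `sklearn`, `faiss-cpu` as `faiss`, and `pandas-datareader` as `pandas_datareader`.

```python
import importlib.util

# Import names (not pip names) of the packages the notebook installs.
# This list is an assumption derived from the notebook's pip cell.
REQUIRED_MODULES = [
    "pandas",
    "pandas_datareader",
    "plotly",
    "sklearn",
    "langchain",
    "langchain_experimental",
    "openai",
    "tiktoken",
    "faiss",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("Missing:", missing if missing else "none")
```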

## Prerequisites

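- Python 3.x
- Jupyter Notebook or JupyterLab
- An OpenAI API key (the LangChain cells call the OpenAI API)
- Internet access, since the notebook downloads its data from FRED
- Basic understanding of SQL and Python
- Familiarity with financial datasets (optional but beneficial)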

## Installation

Clone the repository using:

```bash
git clone https://github.com/anjoGF/Medium_LangChain_Demo.git
```

In [1]:
```bash
%%bash
pip install pandas
pip install scikit-learn
pip install pandas-datareader
pip install plotly
pip install langchain langchain-experimental
pip install openai
pip install tiktoken
pip install faiss-cpu
```



In [4]:
```python
import sqlite3

def create_tables(db_name):
    """Create the four tables used by the demo if they do not already exist.

    Note: the DataLoader cell below writes with pandas to_sql(if_exists='replace'),
    which drops and recreates economic_indicators, yield_curve_prices and
    production_data, so the schemas (including PRIMARY KEY) defined here for
    those three tables are replaced by pandas-generated ones.
    """
    with sqlite3.connect(db_name) as conn:
        cursor = conn.cursor()

        # Create economic_indicators table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS economic_indicators (
                Date TEXT PRIMARY KEY,
                UNRATE REAL,
                PAYEMS REAL,
                ICSA REAL,
                CIVPART REAL
            )
        ''')

        # Create yield_curve_prices table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS yield_curve_prices (
                Date TEXT PRIMARY KEY,
                DGS1MO REAL,
                DGS3MO REAL,
                DGS6MO REAL,
                DGS1 REAL,
                DGS2 REAL,
                DGS3 REAL,
                DGS5 REAL,
                DGS7 REAL,
                DGS10 REAL,
                DGS20 REAL,
                DGS30 REAL
            )
        ''')

        # Create production_data table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS production_data (
                Date TEXT PRIMARY KEY,
                SAUNGDPMOMBD REAL,
                ARENGDPMOMBD REAL,
                IRNNGDPMOMBD REAL,
                SAUNXGO REAL,
                QATNGDPMOMBD REAL,
                KAZNGDPMOMBD REAL,
                IRQNXGO REAL,
                IRNNXGO REAL,
                KWTNGDPMOMBD REAL,
                IPN213111S REAL,
                PCU213111213111 REAL
            )
        ''')

        # Create business_cycles table (IF NOT EXISTS makes the cell safe to re-run)
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS business_cycles (
                Peak_Month TEXT,
                Trough_Month TEXT,
                Start_Date TEXT,
                End_Date TEXT,
                Phase TEXT
            )
        ''')

        print("Tables created successfully.")

create_tables("financial_data.db")
```

Tables created successfully.
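Notebook cells are often re-run, and `CREATE TABLE IF NOT EXISTS` is what makes a setup cell like this safe to repeat. A plain `CREATE TABLE` raises an error the second time. The minimal sketch below shows the difference against a throwaway in-memory database (not `financial_data.db`).

```python
import sqlite3

conn = sqlite3.connect(":memory:")  # throwaway in-memory database

ddl_plain = "CREATE TABLE business_cycles (Phase TEXT)"
ddl_guarded = "CREATE TABLE IF NOT EXISTS business_cycles (Phase TEXT)"

conn.execute(ddl_plain)  # first run: the table is created

# Running the plain statement again fails because the table already exists.
try:
    conn.execute(ddl_plain)
    rerun_error = None
except sqlite3.OperationalError as exc:
    rerun_error = str(exc)

# The guarded statement is a no-op when the table already exists.
conn.execute(ddl_guarded)
conn.close()

print(rerun_error)
```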


In [5]:
```python
import sqlite3

def insert_business_cycle_data(db_name):
    # Note: in every entry, "peak" is simply the phase's start month and "trough" its
    # end month; for Expansion rows this is the reverse of the usual NBER meaning.
    business_cycles = [
        {"peak": "1999-03-01", "trough": "2001-03-01", "start": "1999-03-01 00:00:00", "end": "2001-03-01 00:00:00", "phase": "Expansion"},
        {"peak": "2001-03-01", "trough": "2001-11-01", "start": "2001-03-01 00:00:00", "end": "2001-11-01 00:00:00", "phase": "Contraction"},
        {"peak": "2001-11-01", "trough": "2007-12-01", "start": "2001-11-01 00:00:00", "end": "2007-12-01 00:00:00", "phase": "Expansion"},
        {"peak": "2007-12-01", "trough": "2009-06-01", "start": "2007-12-01 00:00:00", "end": "2009-06-01 00:00:00", "phase": "Contraction"},
        {"peak": "2009-06-01", "trough": "2020-02-01", "start": "2009-06-01 00:00:00", "end": "2020-02-01 00:00:00", "phase": "Expansion"},
        {"peak": "2020-02-01", "trough": "2020-04-01", "start": "2020-02-01 00:00:00", "end": "2020-04-01 00:00:00", "phase": "Contraction"},
        {"peak": "2020-04-01", "trough": "2023-12-01", "start": "2020-04-01 00:00:00", "end": "2023-12-01 00:00:00", "phase": "Expansion"}
    ]

    with sqlite3.connect(db_name) as conn:
        cursor = conn.cursor()
        # Clear existing rows so re-running the cell does not insert duplicates.
        cursor.execute("DELETE FROM business_cycles")
        cursor.executemany('''
            INSERT INTO business_cycles (Peak_Month, Trough_Month, Start_Date, End_Date, Phase)
            VALUES (?, ?, ?, ?, ?)
        ''', [(c["peak"], c["trough"], c["start"], c["end"], c["phase"]) for c in business_cycles])
        print("Business cycle data inserted successfully.")

insert_business_cycle_data("financial_data.db")
```


Business cycle data inserted successfully.
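Questions about "the last economic cycle" ultimately require mapping each date to a phase in `business_cycles`. Below is a minimal sketch of that lookup. It assumes the text date format used above. `phase_for_date` is a hypothetical helper, not part of the notebook. Consecutive phases share a boundary date (one phase's `End_Date` is the next phase's `Start_Date`). The lookup therefore uses a half-open range, so that each date matches exactly one phase.

```python
import sqlite3


def phase_for_date(conn, date):
    """Return the Phase whose [Start_Date, End_Date) range contains date, or None.

    Dates are compared as ISO-8601 text, matching how the notebook stores them.
    """
    row = conn.execute(
        "SELECT Phase FROM business_cycles WHERE Start_Date <= ? AND ? < End_Date",
        (date, date),
    ).fetchone()
    return row[0] if row else None


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE business_cycles (Start_Date TEXT, End_Date TEXT, Phase TEXT)")
# Two rows copied from the business_cycles list above.
conn.executemany(
    "INSERT INTO business_cycles VALUES (?, ?, ?)",
    [
        ("2001-11-01 00:00:00", "2007-12-01 00:00:00", "Expansion"),
        ("2007-12-01 00:00:00", "2009-06-01 00:00:00", "Contraction"),
    ],
)
print(phase_for_date(conn, "2008-06-15 00:00:00"))  # Contraction
```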


In [6]:
```python
import pandas as pd
import pandas_datareader.data as web
import sqlite3
from datetime import datetime

class DataLoader:
    def __init__(self, db_name="financial_data.db"):
        self.db_name = db_name
        self.economic_indicators_tickers = ['UNRATE', 'PAYEMS', 'ICSA', 'CIVPART']
        self.yield_curve_tickers = ['DGS1MO', 'DGS3MO', 'DGS6MO', 'DGS1', 'DGS2', 'DGS3', 'DGS5', 'DGS7', 'DGS10', 'DGS20', 'DGS30']
        self.production_data_tickers = ['SAUNGDPMOMBD', 'ARENGDPMOMBD', 'IRNNGDPMOMBD', 'SAUNXGO', 'QATNGDPMOMBD', 'KAZNGDPMOMBD', 'IRQNXGO', 'IRNNXGO', 'KWTNGDPMOMBD', 'IPN213111S', 'PCU213111213111']

    def fetch_and_insert_data(self, tickers, table_name):
        start_date = '2000-12-31'
        end_date = datetime.now().strftime('%Y-%m-%d')
        try:
            # Download all tickers from FRED as one DataFrame indexed by date.
            data = web.DataReader(tickers, 'fred', start_date, end_date)
            # Fill interior gaps (quadratic interpolation requires SciPy), then pad the edges.
            data = data.interpolate(method='quadratic').ffill().bfill()
            # Upsample to daily rows, carrying each observation forward.
            data = data.resample('D').ffill().bfill()

            with sqlite3.connect(self.db_name) as conn:
                # if_exists='replace' drops the table created by create_tables() and
                # recreates it from the DataFrame's dtypes (without a PRIMARY KEY).
                data.to_sql(table_name, conn, if_exists='replace', index_label='Date')
            print(f"Data inserted into {table_name} table.")
        except Exception as e:
            print(f"Failed to fetch and insert data: {e}")

def main():
    db_name = "financial_data.db"
    loader = DataLoader(db_name)
    loader.fetch_and_insert_data(loader.economic_indicators_tickers, 'economic_indicators')
    loader.fetch_and_insert_data(loader.yield_curve_tickers, 'yield_curve_prices')
    loader.fetch_and_insert_data(loader.production_data_tickers, 'production_data')

if __name__ == "__main__":
    main()
```


Data inserted into economic_indicators table.
Data inserted into yield_curve_prices table.
Data inserted into production_data table.
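`fetch_and_insert_data` writes with `if_exists='replace'`. pandas therefore drops the table created earlier and recreates it from the DataFrame, so the `PRIMARY KEY` declared in `create_tables` does not survive. This also explains why the agent later reports the column as `"Date" TIMESTAMP`. The sketch below reproduces the effect on an in-memory database with a tiny hypothetical frame. It reads back the stored DDL from `sqlite_master`.

```python
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")

# Schema in the style of create_tables(): Date is declared as the PRIMARY KEY.
conn.execute("CREATE TABLE economic_indicators (Date TEXT PRIMARY KEY, UNRATE REAL)")

# Hypothetical two-row frame with a DatetimeIndex, like the FRED download.
df = pd.DataFrame(
    {"UNRATE": [4.2, 4.2]},
    index=pd.to_datetime(["2001-01-01", "2001-01-02"]),
)
df.to_sql("economic_indicators", conn, if_exists="replace", index_label="Date")

# Read back the DDL that SQLite now holds for the table.
ddl = conn.execute(
    "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = 'economic_indicators'"
).fetchone()[0]
print(ddl)
```

If the primary key matters, one option is to write with `if_exists='append'` into the pre-created table instead. Re-running the loader would then need to handle duplicate dates.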


In [7]:
```python
import sqlite3
import pandas as pd

# Database name
db_name = "financial_data.db"

# Connect to the SQLite database
conn = sqlite3.connect(db_name)

# Fetch data from economic_indicators table
economic_indicators_query = "SELECT * FROM economic_indicators"
economic_indicators_df = pd.read_sql(economic_indicators_query, conn)
print("Economic Indicators Data:\n", economic_indicators_df.head())

# Fetch data from yield_curve_prices table
yield_curve_prices_query = "SELECT * FROM yield_curve_prices"
yield_curve_prices_df = pd.read_sql(yield_curve_prices_query, conn)
print("\nYield Curve Prices Data:\n", yield_curve_prices_df.head())

# Fetch data from production_data table
production_data_query = "SELECT * FROM production_data"
production_data_df = pd.read_sql(production_data_query, conn)
print("\nProduction Data:\n", production_data_df.head())

# Fetch data from business_cycles table
business_cycles_query = "SELECT * FROM business_cycles"
business_cycles_df = pd.read_sql(business_cycles_query, conn)
print("\nBusiness Cycles Data:\n", business_cycles_df.head())

# Close the database connection
conn.close()
```

Economic Indicators Data:
                   Date  UNRATE    PAYEMS      ICSA  CIVPART
0  2001-01-01 00:00:00     4.2  132698.0  337000.0     67.2
1  2001-01-02 00:00:00     4.2  132698.0  337000.0     67.2
2  2001-01-03 00:00:00     4.2  132698.0  337000.0     67.2
3  2001-01-04 00:00:00     4.2  132698.0  337000.0     67.2
4  2001-01-05 00:00:00     4.2  132698.0  337000.0     67.2
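The first five rows above are identical (UNRATE 4.2, PAYEMS 132698.0, and so on). This comes from the cleaning steps in `fetch_and_insert_data`: the monthly FRED series are upsampled to daily rows, and each value is carried forward until the next observation. The minimal sketch below applies the same steps to a hypothetical three-month series. It uses linear interpolation instead of the notebook's quadratic method, so SciPy is not needed.

```python
import numpy as np
import pandas as pd

# Hypothetical monthly series with one missing observation (not real FRED data).
monthly = pd.DataFrame(
    {"UNRATE": [4.0, np.nan, 4.4]},
    index=pd.to_datetime(["2001-01-01", "2001-02-01", "2001-03-01"]),
)

# Fill the gap (linear here, quadratic in the notebook), then pad the edges.
filled = monthly.interpolate(method="linear").ffill().bfill()

# Upsample to one row per day; each monthly value repeats until the next month.
daily = filled.resample("D").ffill().bfill()

print(len(daily))  # 60
print(daily.loc["2001-01-01":"2001-01-03"])
```

One consequence is that averages over these daily rows weight each monthly value by the number of days it covers.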


In [None]:
```python
import sqlite3
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
import openai

# Set your OpenAI API key
API_KEY="your-key"
openai.api_key = API_KEY

# Initialize the SQLite database
db_uri = "sqlite:///financial_data.db"  # Path to your SQLite database
db = SQLDatabase.from_uri(db_uri)

# Initialize the LangChain LLM (Large Language Model) with OpenAI
llm = OpenAI(openai_api_key=API_KEY, temperature=0.2, verbose=True)

# Create a SQLDatabaseChain instance for building and executing SQL queries
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

# Example: Generate and execute a SQL query from a natural language question
# Replace the string with your natural language question
natural_language_question = "Describe the median short rates during the last year?"

# Generate and execute SQL query
response = db_chain.run(natural_language_question)
print("Response:", response)
```




[1m> Entering new SQLDatabaseChain chain...[0m
Describe the median short rates during the last year?
SQLQuery:[32;1m[1;3mSELECT AVG(DGS1MO) FROM yield_curve_prices WHERE Date BETWEEN date('now', '-1 year') AND date('now')[0m
SQLResult: [33;1m[1;3m[(5.1122075716281135,)][0m
Answer:[32;1m[1;3mThe median short rates during the last year was 5.1122075716281135.[0m
[1m> Finished chain.[0m
Response: The median short rates during the last year was 5.1122075716281135.
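The generated query computes `AVG(DGS1MO)`, the mean of the 1-month Treasury yield, but the answer reports this value as the median. SQLite has no built-in `MEDIAN` aggregate. One way to get a true median with Python's `sqlite3` module is to register a user-defined aggregate. The sketch below does this on an in-memory table with hypothetical yield values. Making `MEDIAN` usable by `SQLDatabaseChain` would also require registering it on the connections of the SQLAlchemy engine behind `SQLDatabase`, which this sketch does not do.

```python
import sqlite3
import statistics


class Median:
    """User-defined SQLite aggregate returning the median of non-NULL values."""

    def __init__(self):
        self.values = []

    def step(self, value):
        if value is not None:
            self.values.append(value)

    def finalize(self):
        return statistics.median(self.values) if self.values else None


conn = sqlite3.connect(":memory:")
conn.create_aggregate("median", 1, Median)

# Hypothetical 1-month yields (not FRED data); one outlier pulls the mean up.
conn.execute("CREATE TABLE yield_curve_prices (Date TEXT, DGS1MO REAL)")
conn.executemany(
    "INSERT INTO yield_curve_prices VALUES (?, ?)",
    [("2023-01-03", 4.0), ("2023-01-04", 4.2), ("2023-01-05", 4.1), ("2023-01-06", 6.0)],
)

avg, med = conn.execute(
    "SELECT AVG(DGS1MO), MEDIAN(DGS1MO) FROM yield_curve_prices"
).fetchone()
print(f"mean={avg:.3f} median={med:.3f}")
```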


## SQL Agents

In [None]:
```python
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType


# Create a SQL agent that reuses the `llm` and `db` objects defined above
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    top_k=20  # the prompt tells the agent to limit queries to 20 results unless asked otherwise
)

# Example query using the SQL agent
query_result = agent_executor.run(
    "Discuss the yield curve, and their median values, over the last economic cycle."
)

print("Query Result:", query_result)

# No sqlite3 connection needs closing here: SQLDatabase manages its own
# SQLAlchemy connections.
```



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3mbusiness_cycles, economic_indicators, production_data, yield_curve_prices[0m
Thought:[32;1m[1;3m I should query the schema of the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices[0m
Observation: [33;1m[1;3m
CREATE TABLE yield_curve_prices (
	"Date" TIMESTAMP, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I should query the median

In [None]:
```python
# Reuses the agent_executor created in the previous cell
agent_executor.run("Describe the yield_curve_prices table")
```




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3mbusiness_cycles, economic_indicators, production_data, yield_curve_prices[0m
Thought:[32;1m[1;3m I should query the schema of the yield_curve_prices table.
Action: sql_db_schema
Action Input: yield_curve_prices[0m
Observation: [33;1m[1;3m
CREATE TABLE yield_curve_prices (
	"Date" TIMESTAMP, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 rows from yield_curve_prices table:
Date	DGS1MO	DGS3MO	DGS6MO	DGS1	DGS2	DGS3	DGS5	DGS7	DGS10	DGS20	DGS30
2001-01-01 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-02 00:00:00	3.67	5.87	5.58	5.11	4.87	4.82	4.76	4.97	4.92	5.46	5.35
2001-01-03 00:00:00	3.67	5.69	5.44	5.04	4.92	4.92	4.94	5.18	5.14	5.62	5.49
*/
Thought: I now know the final answ

'The yield_curve_prices table contains information on the daily yield curve rates for various maturities. It has 12 columns, including the date and rates for 1 month, 3 months, 6 months, 1 year, 2 years, 3 years, 5 years, 7 years, 10 years, 20 years, and 30 years.'
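Each `sql_db_schema` observation above consists of a table's `CREATE TABLE` statement followed by three sample rows. The sketch below is a rough approximation using only the standard library, handy for checking what the agent saw. `describe_table` is a hypothetical helper. Its output only loosely follows LangChain's format, which this sketch does not reproduce exactly.

```python
import sqlite3


def describe_table(conn, table, sample_rows=3):
    """Return a table's CREATE statement followed by a header and a few sample rows."""
    ddl_row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if ddl_row is None:
        raise ValueError(f"no such table: {table}")
    # Table names cannot be bound as parameters, so quote the identifier instead.
    quoted = '"' + table.replace('"', '""') + '"'
    cursor = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cursor.description)
    rows = ["\t".join(str(value) for value in row) for row in cursor.fetchall()]
    return "\n".join([ddl_row[0], "", header, *rows])


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE yield_curve_prices ("Date" TIMESTAMP, "DGS1MO" REAL, "DGS10" REAL)')
# Values taken from the sample rows shown in the agent output above.
conn.executemany(
    "INSERT INTO yield_curve_prices VALUES (?, ?, ?)",
    [("2001-01-01 00:00:00", 3.67, 4.92), ("2001-01-02 00:00:00", 3.67, 4.92)],
)
print(describe_table(conn, "yield_curve_prices"))
```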

## Creating a Vector Database of Few-Shot Examples

In [None]:
```python
import sqlite3
import pandas as pd
from langchain.llms import OpenAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.agents.agent_toolkits import create_retriever_tool
import openai

# Set your OpenAI API key
API_KEY="your-key"
openai.api_key = API_KEY

# Initialize the SQLite database
db_uri = "sqlite:///financial_data.db"
db = SQLDatabase.from_uri(db_uri)

# Initialize the LangChain LLM with OpenAI
llm = OpenAI(openai_api_key=API_KEY, temperature=0.2, verbose=True)

# Define few-shot examples specific to your database.
# Note: the fixed dates assume the notebook runs in 2023 ("last year" = 2022).
few_shots = {
    "What was the highest unemployment rate last year?": "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "Average industrial production for the previous month?": "SELECT AVG(IPN213111S) FROM production_data WHERE Date >= date('now', 'start of month', '-1 month') AND Date < date('now', 'start of month');",
    "Show the five lowest 10-year yield rates of the current year.": "SELECT * FROM yield_curve_prices WHERE Date >= '2023-01-01' ORDER BY DGS10 ASC LIMIT 5;",
    "List the economic indicators for the first quarter of 2023.": "SELECT * FROM economic_indicators WHERE Date >= '2023-01-01' AND Date <= '2023-03-31';",
    "What are the latest available production numbers for Saudi Arabia?": "SELECT SAUNGDPMOMBD FROM production_data ORDER BY Date DESC LIMIT 1;",
    "Compare the unemployment rate at the beginning and end of the last recession.": "SELECT UNRATE FROM economic_indicators WHERE Date IN (SELECT Start_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY Start_Date DESC LIMIT 1) OR Date IN (SELECT End_Date FROM business_cycles WHERE Phase = 'Contraction' ORDER BY End_Date DESC LIMIT 1);",
    "Find the average civilian labor force participation rate for the last year.": "SELECT AVG(CIVPART) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "Show the change in 2-year yield rates over the past six months.": "SELECT DGS2 FROM yield_curve_prices WHERE Date >= date('now', '-6 months') ORDER BY Date;",
    "What was the maximum production of natural gas in Qatar last year?": "SELECT MAX(QATNGDPMOMBD) FROM production_data WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "List the top 3 longest economic expansions since 2000.": "SELECT Start_Date, End_Date FROM business_cycles WHERE Phase = 'Expansion' AND Start_Date >= '2000-01-01' ORDER BY (julianday(End_Date) - julianday(Start_Date)) DESC LIMIT 3;"
}

# Create a retriever for few-shot examples. Question and SQL are stored together in
# page_content because the retriever tool passes only page_content to the agent;
# SQL kept solely in metadata would never reach it.
embeddings = OpenAIEmbeddings(openai_api_key=API_KEY)
few_shot_docs = [
    Document(page_content=f"Question: {question}\nSQL: {sql}", metadata={"sql_query": sql})
    for question, sql in few_shots.items()
]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()

# Create a custom tool for retrieving similar examples
retriever_tool = create_retriever_tool(
    retriever,
    name="sql_get_similar_examples",
    description="Returns example questions similar to the input, each paired with SQL that answers it. Input should be the user question.",
)

# Create the SQL agent with the custom tool
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    extra_tools=[retriever_tool],
    top_k=20
)

# Example query using the SQL agent
query_result = agent_executor.run(
    "Discuss the yield curve, and their median values, over the last economic cycle."
)
print("Query Result:", query_result)
```




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input:
[0m
Observation: [38;5;200m[1;3mbusiness_cycles, economic_indicators, production_data, yield_curve_prices[0m
Thought:[32;1m[1;3m I should query the yield_curve_prices table and the economic_indicators table.
Action: sql_db_schema
Action Input: yield_curve_prices, economic_indicators[0m
Observation: [33;1m[1;3m
CREATE TABLE economic_indicators (
	"Date" TIMESTAMP, 
	"UNRATE" REAL, 
	"PAYEMS" REAL, 
	"ICSA" REAL, 
	"CIVPART" REAL
)

/*
3 rows from economic_indicators table:
Date	UNRATE	PAYEMS	ICSA	CIVPART
2001-01-01 00:00:00	4.2	132698.0	337000.0	67.2
2001-01-02 00:00:00	4.2	132698.0	337000.0	67.2
2001-01-03 00:00:00	4.2	132698.0	337000.0	67.2
*/


CREATE TABLE yield_curve_prices (
	"Date" TIMESTAMP, 
	"DGS1MO" REAL, 
	"DGS3MO" REAL, 
	"DGS6MO" REAL, 
	"DGS1" REAL, 
	"DGS2" REAL, 
	"DGS3" REAL, 
	"DGS5" REAL, 
	"DGS7" REAL, 
	"DGS10" REAL, 
	"DGS20" REAL, 
	"DGS30" REAL
)

/*
3 
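The few-shot retriever matches the user's question against the stored example questions by embedding similarity, using OpenAI embeddings in a FAISS index. The dependency-free sketch below replaces that with a much cruder technique, bag-of-words cosine similarity, purely to illustrate the retrieval step. `FEW_SHOTS` and `similar_examples` are hypothetical names. The sketch returns each match with its SQL included in the text, because the agent can only use what a tool returns as text.

```python
import math
import re
from collections import Counter

# A few (question, SQL) pairs copied from the few_shots dictionary above.
FEW_SHOTS = {
    "What was the highest unemployment rate last year?":
        "SELECT MAX(UNRATE) FROM economic_indicators WHERE Date >= '2022-01-01' AND Date <= '2022-12-31';",
    "What are the latest available production numbers for Saudi Arabia?":
        "SELECT SAUNGDPMOMBD FROM production_data ORDER BY Date DESC LIMIT 1;",
    "Show the change in 2-year yield rates over the past six months.":
        "SELECT DGS2 FROM yield_curve_prices WHERE Date >= date('now', '-6 months') ORDER BY Date;",
}


def _vector(text):
    """Lower-cased word counts: a crude stand-in for an embedding."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def _cosine(a, b):
    dot = sum(a[word] * b[word] for word in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def similar_examples(question, examples=FEW_SHOTS, k=1):
    """Return the k most similar examples, each formatted with its SQL."""
    query_vec = _vector(question)
    ranked = sorted(examples, key=lambda q: _cosine(query_vec, _vector(q)), reverse=True)
    return "\n\n".join(f"Question: {q}\nSQL: {examples[q]}" for q in ranked[:k])


print(similar_examples("What was the peak unemployment rate in 2022?"))
```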

