# 1. Setup

In [21]:
import sqlite3
import pandas as pd
import os
import dotenv
import json
from typing import List

from llama_index.experimental.query_engine import PandasQueryEngine
from langchain_experimental.agents.agent_toolkits import create_csv_agent, create_pandas_dataframe_agent
from langchain_community.llms.openai import OpenAI
from pandasai import Agent
from pandasai.llm import OpenAI
from pandasai import SmartDataframe
from langchain_core.messages.human import HumanMessage
from langchain_community.chat_models.openai import ChatOpenAI
from langchain.agents.agent_types import AgentType

dotenv.load_dotenv()

True

# 2. Baselines with LangChain

## 2.1 Fewshot Examples + ICL

In [2]:
llm = ChatOpenAI(model="gpt-3.5-turbo")

  warn_deprecated(


In [3]:
examples = """Task: Plot the monthly average temperature data over a year.

import matplotlib.pyplot as plt

# Example data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
temperatures = [3, 4, 8, 12, 16, 20, 22, 21, 18, 13, 8, 4]

# Plotting the data
plt.figure(figsize=(10, 5))
plt.plot(months, temperatures, marker='o', linestyle='-', color='b')
plt.title('Monthly Average Temperatures')
plt.xlabel('Month')
plt.ylabel('Temperature (°C)')
plt.grid(True)
plt.show()

Task: Plot the total sales data for different products.

import matplotlib.pyplot as plt

# Example data
products = ['Laptop', 'Smartphone', 'Tablet']
sales = [150, 200, 100]

# Plotting the data
plt.figure(figsize=(7, 5))
plt.bar(products, sales, color='green')
plt.title('Product Sales Comparison')
plt.xlabel('Products')
plt.ylabel('Total Sales')
plt.show()

Task: Plot a scatter plot to analyze the relationship between hours studied and exam scores.

import matplotlib.pyplot as plt

# Example data
hours_studied = [1, 2, 3, 4, 5, 6, 7, 8]
exam_scores = [55, 60, 65, 70, 75, 85, 90, 95]

# Plotting the data
plt.figure(figsize=(8, 5))
plt.scatter(hours_studied, exam_scores, color='red')
plt.title('Study Time vs. Exam Scores')
plt.xlabel('Hours Studied')
plt.ylabel('Exam Score')
plt.show()

Task: Plot a histogram to observe the distribution of daily visitors to a museum over a month.

import matplotlib.pyplot as plt
import numpy as np

# Example data
visitors = np.random.normal(250, 50, 30)  # Simulating random daily visitor counts

# Plotting the data
plt.figure(figsize=(8, 5))
plt.hist(visitors, bins=10, color='purple')
plt.title('Daily Visitors Distribution')
plt.xlabel('Visitors')
plt.ylabel('Frequency')
plt.show()"""

prompt = """Solve a Python data exploration task.

Here are some examples:
{examples}
(END OF EXAMPLES)

The following Python data exploration task is on a pandas DataFrame on wildfires. Below are the columns.

Fires: Table including wildfire data for the period of 1992-2015 compiled from US federal, state, and local reporting systems.
FOD_ID = Global unique identifier.
FPA_ID = Unique identifier that contains information necessary to track back to the original record in the source dataset.
SOURCE_SYSTEM_TYPE = Type of source database or system that the record was drawn from (federal, nonfederal, or interagency).
SOURCE_SYSTEM = Name of or other identifier for source database or system that the record was drawn from. See Table 1 in Short (2014), or \Supplements\FPA_FOD_source_list.pdf, for a list of sources and their identifier.
NWCG_REPORTING_AGENCY = Active National Wildlife Coordinating Group (NWCG) Unit Identifier for the agency preparing the fire report (BIA = Bureau of Indian Affairs, BLM = Bureau of Land Management, BOR = Bureau of Reclamation, DOD = Department of Defense, DOE = Department of Energy, FS = Forest Service, FWS = Fish and Wildlife Service, IA = Interagency Organization, NPS = National Park Service, ST/C&L = State, County, or Local Organization, and TRIBE = Tribal Organization).
NWCG_REPORTING_UNIT_ID = Active NWCG Unit Identifier for the unit preparing the fire report.
NWCG_REPORTING_UNIT_NAME = Active NWCG Unit Name for the unit preparing the fire report.
SOURCE_REPORTING_UNIT = Code for the agency unit preparing the fire report, based on code/name in the source dataset.
SOURCE_REPORTING_UNIT_NAME = Name of reporting agency unit preparing the fire report, based on code/name in the source dataset.
LOCAL_FIRE_REPORT_ID = Number or code that uniquely identifies an incident report for a particular reporting unit and a particular calendar year.
LOCAL_INCIDENT_ID = Number or code that uniquely identifies an incident for a particular local fire management organization within a particular calendar year.
FIRE_CODE = Code used within the interagency wildland fire community to track and compile cost information for emergency fire suppression (https://www.firecode.gov/).
FIRE_NAME = Name of the incident, from the fire report (primary) or ICS-209 report (secondary).
ICS_209_INCIDENT_NUMBER = Incident (event) identifier, from the ICS-209 report.
ICS_209_NAME = Name of the incident, from the ICS-209 report.
MTBS_ID = Incident identifier, from the MTBS perimeter dataset.
MTBS_FIRE_NAME = Name of the incident, from the MTBS perimeter dataset.
COMPLEX_NAME = Name of the complex under which the fire was ultimately managed, when discernible.
FIRE_YEAR = Calendar year in which the fire was discovered or confirmed to exist.
DISCOVERY_DATE = Date on which the fire was discovered or confirmed to exist.
DISCOVERY_DOY = Day of year on which the fire was discovered or confirmed to exist.
DISCOVERY_TIME = Time of day that the fire was discovered or confirmed to exist.
STAT_CAUSE_CODE = Code for the (statistical) cause of the fire.
STAT_CAUSE_DESCR = Description of the (statistical) cause of the fire.
CONT_DATE = Date on which the fire was declared contained or otherwise controlled (mm/dd/yyyy where mm=month, dd=day, and yyyy=year).
CONT_DOY = Day of year on which the fire was declared contained or otherwise controlled.
CONT_TIME = Time of day that the fire was declared contained or otherwise controlled (hhmm where hh=hour, mm=minutes).
FIRE_SIZE = Estimate of acres within the final perimeter of the fire.
FIRE_SIZE_CLASS = Code for fire size based on the number of acres within the final fire perimeter expenditures (A=greater than 0 but less than or equal to 0.25 acres, B=0.26-9.9 acres, C=10.0-99.9 acres, D=100-299 acres, E=300 to 999 acres, F=1000 to 4999 acres, and G=5000+ acres).
LATITUDE = Latitude (NAD83) for point location of the fire (decimal degrees).
LONGITUDE = Longitude (NAD83) for point location of the fire (decimal degrees).
OWNER_CODE = Code for primary owner or entity responsible for managing the land at the point of origin of the fire at the time of the incident.
OWNER_DESCR = Name of primary owner or entity responsible for managing the land at the point of origin of the fire at the time of the incident.
STATE = Two-letter alphabetic code for the state in which the fire burned (or originated), based on the nominal designation in the fire report.
COUNTY = County, or equivalent, in which the fire burned (or originated), based on nominal designation in the fire report.
FIPS_CODE = Three-digit code from the Federal Information Process Standards (FIPS) publication 6-4 for representation of counties and equivalent entities.
FIPS_NAME = County name from the FIPS publication 6-4 for representation of counties and equivalent entities.
NWCG_UnitIDActive_20170109: Look-up table containing all NWCG identifiers for agency units that were active (i.e., valid) as of 9 January 2017, when the list was downloaded from https://www.nifc.blm.gov/unit_id/Publish.html and used as the source of values available to populate the following fields in the Fires table: NWCG_REPORTING_AGENCY, NWCG_REPORTING_UNIT_ID, and NWCG_REPORTING_UNIT_NAME.
UnitId = NWCG Unit ID.
GeographicArea = Two-letter code for the geographic area in which the unit is located (NA=National, IN=International, AK=Alaska, CA=California, EA=Eastern Area, GB=Great Basin, NR=Northern Rockies, NW=Northwest, RM=Rocky Mountain, SA=Southern Area, and SW=Southwest).
Gacc = Seven or eight-letter code for the Geographic Area Coordination Center in which the unit is located or primarily affiliated with (CAMBCIFC=Canadian Interagency Forest Fire Centre, USAKCC=Alaska Interagency Coordination Center, USCAONCC=Northern California Area Coordination Center, USCAOSCC=Southern California Coordination Center, USCORMCC=Rocky Mountain Area Coordination Center, USGASAC=Southern Area Coordination Center, USIDNIC=National Interagency Coordination Center, USMTNRC=Northern Rockies Coordination Center, USNMSWC=Southwest Area Coordination Center, USORNWC=Northwest Area Coordination Center, USUTGBC=Western Great Basin Coordination Center, USWIEACC=Eastern Area Coordination Center).
WildlandRole = Role of the unit within the wildland fire community.
UnitType = Type of unit (e.g., federal, state, local).
Department = Department (or state/territory) to which the unit belongs (AK=Alaska, AL=Alabama, AR=Arkansas, AZ=Arizona, CA=California, CO=Colorado, CT=Connecticut, DE=Delaware, DHS=Department of Homeland Security, DOC= Department of Commerce, DOD=Department of Defense, DOE=Department of Energy, DOI= Department of Interior, DOL=Department of Labor, FL=Florida, GA=Georgia, IA=Iowa, IA/GC=Non-Departmental Agencies, ID=Idaho, IL=Illinois, IN=Indiana, KS=Kansas, KY=Kentucky, LA=Louisiana, MA=Massachusetts, MD=Maryland, ME=Maine, MI=Michigan, MN=Minnesota, MO=Missouri, MS=Mississippi, MT=Montana, NC=North Carolina, NE=Nebraska, NG=Non-Government, NH=New Hampshire, NJ=New Jersey, NM=New Mexico, NV=Nevada, NY=New York, OH=Ohio, OK=Oklahoma, OR=Oregon, PA=Pennsylvania, PR=Puerto Rico, RI=Rhode Island, SC=South Carolina, SD=South Dakota, ST/L=State or Local Government, TN=Tennessee, Tribe=Tribe, TX=Texas, USDA=Department of Agriculture, UT=Utah, VA=Virginia, VI=U. S. Virgin Islands, VT=Vermont, WA=Washington, WI=Wisconsin, WV=West Virginia, WY=Wyoming).
Agency = Agency or bureau to which the unit belongs (AG=Air Guard, ANC=Alaska Native Corporation, BIA=Bureau of Indian Affairs, BLM=Bureau of Land Management, BOEM=Bureau of Ocean Energy Management, BOR=Bureau of Reclamation, BSEE=Bureau of Safety and Environmental Enforcement, C&L=County & Local, CDF=California Department of Forestry & Fire Protection, DC=Department of Corrections, DFE=Division of Forest Environment, DFF=Division of Forestry Fire & State Lands, DFL=Division of Forests and Land, DFR=Division of Forest Resources, DL=Department of Lands, DNR=Department of Natural Resources, DNRC=Department of Natural Resources and Conservation, DNRF=Department of Natural Resources Forest Service, DOA=Department of Agriculture, DOC=Department of Conservation, DOE=Department of Energy, DOF=Department of Forestry, DVF=Division of Forestry, DWF=Division of Wildland Fire, EPA=Environmental Protection Agency, FC=Forestry Commission, FEMA=Federal Emergency Management Agency, FFC=Bureau of Forest Fire Control, FFP=Forest Fire Protection, FFS=Forest Fire Service, FR=Forest Rangers, FS=Forest Service, FWS=Fish & Wildlife Service, HQ=Headquarters, JC=Job Corps, NBC=National Business Center, NG=National Guard, NNSA=National Nuclear Security Administration, NPS=National Park Service, NWS=National Weather Service, OES=Office of Emergency Services, PRI=Private, SF=State Forestry, SFS=State Forest Service, SP=State Parks, TNC=The Nature Conservancy, USA=United States Army, USACE=United States Army Corps of Engineers, USAF=United States Air Force, USGS=United States Geological Survey, USN=United States Navy).
Parent = Agency subgroup to which the unit belongs (A concatenation of State and Unit from this report - https://www.nifc.blm.gov/unit_id/publish/UnitIdReport.rtf).
Country = Country in which the unit is located (e.g. US = United States).
State = Two-letter code for the state in which the unit is located (or primarily affiliated).
Code = Unit code (follows state code to create UnitId).
Name = Unit name.

Below are the datatypes for each column.

RangeIndex: 1880465 entries, 0 to 1880464
Data columns (total 39 columns):
 #   Column                      Dtype
---  ------                      -----
 0   OBJECTID                    int64
 1   FOD_ID                      int64
 2   FPA_ID                      object
 3   SOURCE_SYSTEM_TYPE          object
 4   SOURCE_SYSTEM               object
 5   NWCG_REPORTING_AGENCY       object
 6   NWCG_REPORTING_UNIT_ID      object
 7   NWCG_REPORTING_UNIT_NAME    object
 8   SOURCE_REPORTING_UNIT       object
 9   SOURCE_REPORTING_UNIT_NAME  object
 10  LOCAL_FIRE_REPORT_ID        object
 11  LOCAL_INCIDENT_ID           object
 12  FIRE_CODE                   object
 13  FIRE_NAME                   object
 14  ICS_209_INCIDENT_NUMBER     object
 15  ICS_209_NAME                object
 16  MTBS_ID                     object
 17  MTBS_FIRE_NAME              object
 18  COMPLEX_NAME                object
 19  FIRE_YEAR                   int64
 20  DISCOVERY_DATE              datetime64[ns]
 21  DISCOVERY_DOY               int64
 22  DISCOVERY_TIME              float64
 23  STAT_CAUSE_CODE             float64
 24  STAT_CAUSE_DESCR            object
 25  CONT_DATE                   float64
 26  CONT_DOY                    float64
 27  CONT_TIME                   float64
 28  FIRE_SIZE                   float64
 29  FIRE_SIZE_CLASS             object
 30  LATITUDE                    float64
 31  LONGITUDE                   float64
 32  OWNER_CODE                  float64
 33  OWNER_DESCR                 object
 34  STATE                       object
 35  COUNTY                      object
 36  FIPS_CODE                   float64
 37  FIPS_NAME                   object
 38  Shape                       object
dtypes: datetime64[ns](1), float64(10), int64(4), object(24)

Generate Python code to manipulate the dataframe `df` with the above columns and plot the results. Let's think step by step.
Question: {question}

Code:

# Load in the csv.
df = pd.read_csv('wildfires_lite.csv')

"""

In [4]:
question = "What counties are the most and least fire-prone?"
formatted_prompt = prompt.format(examples=examples, question=question)
out = llm(
    [
        HumanMessage(
            content=formatted_prompt,
        )
    ]
).content

  warn_deprecated(


In [5]:
print(out)

# Group the dataframe by county and count the number of fires in each county
county_fires = df['COUNTY'].value_counts()

# Find the county with the most fires
most_fire_prone_county = county_fires.idxmax()
most_fires = county_fires.max()

# Find the county with the least fires
least_fire_prone_county = county_fires.idxmin()
least_fires = county_fires.min()

# Plotting the results
plt.figure(figsize=(10, 5))
county_fires.plot(kind='bar', color='orange')
plt.title('Number of Fires by County')
plt.xlabel('County')
plt.ylabel('Number of Fires')
plt.xticks(rotation=45)
plt.show()

print(f"The most fire-prone county is {most_fire_prone_county} with {most_fires} fires.")
print(f"The least fire-prone county is {least_fire_prone_county} with {least_fires} fires.")


## 2.2 `create_csv_agent`

In [14]:
agent = create_csv_agent(
    ChatOpenAI(temperature=0),
    "data/wildfires_lite.csv",
    verbose=True,
)
agent.invoke({"input": question})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: To determine the most and least fire-prone counties, we need to analyze the frequency of fires in each county.
Action: We can use the `value_counts()` function on the 'COUNTY' column of the dataframe to get the count of fires in each county.
Action Input: df['COUNTY'].value_counts()[0mWe can use the `value_counts()` function on the 'COUNTY' column of the dataframe to get the count of fires in each county. is not a valid tool, try one of [python_repl_ast].[32;1m[1;3mI should use the python_repl_ast tool to execute the command and get the count of fires in each county.
Action: python_repl_ast
Action Input: df['COUNTY'].value_counts()[0m[36;1m[1;3m5.0      824
3.0      441
39.0     369
7.0      360
17.0     339
        ... 
189.0      1
261.0      1
119.0      1
135.0      1
159.0      1
Name: COUNTY, Length: 117, dtype: int64[0m[32;1m[1;3mI now know the final answer
Final Answer: The most fire-prone county is 

{'input': 'What counties are the most and least fire-prone?',
 'output': 'The most fire-prone county is county 5 with 824 fires, and the least fire-prone counties are counties 189, 261, 119, 135, and 159 with only 1 fire each.'}

## 2.3 `create_pandas_dataframe_agent`

In [16]:
df = pd.read_csv("data/wildfires_lite.csv")
agent = create_pandas_dataframe_agent(ChatOpenAI(temperature=0), df, verbose=True)

In [17]:
agent.invoke({"input": question})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: To determine the most and least fire-prone counties, we need to analyze the frequency of fires in each county.
Action: We can use the `value_counts()` function on the 'COUNTY' column of the dataframe to get the count of fires in each county.
Action Input: df['COUNTY'].value_counts()[0mWe can use the `value_counts()` function on the 'COUNTY' column of the dataframe to get the count of fires in each county. is not a valid tool, try one of [python_repl_ast].[32;1m[1;3mWe can use the Python shell to get the count of fires in each county.
Action: python_repl_ast
Action Input: df['COUNTY'].value_counts()[0m[36;1m[1;3m5.0      824
3.0      441
39.0     369
7.0      360
17.0     339
        ... 
189.0      1
261.0      1
119.0      1
135.0      1
159.0      1
Name: COUNTY, Length: 117, dtype: int64[0m

ValueError: An output parsing error occurred. In order to pass this error back to the agent and have it try again, pass `handle_parsing_errors=True` to the AgentExecutor. This is the error: Parsing LLM output produced both a final answer and a parse-able action:: Now that we have the count of fires in each county, we can determine the most and least fire-prone counties based on these counts.
Action: Calculate the most and least fire-prone counties
Action Input: Most fire-prone county: 5.0, Least fire-prone county: 189.0
Final Answer: The most fire-prone county is 5.0 and the least fire-prone county is 189.0.

# 3. LlamaIndex `PandasQueryEngine`

In [2]:

df = pd.read_csv("data/wildfires_lite.csv")
query_engine = PandasQueryEngine(df=df)
response = query_engine.query(
    "Have wildfires become more or less frequent over time?",
)

In [3]:
print(response.response)

2002       1
2004      38
2005    6856
2006    3104
2008       1
Name: FIRE_YEAR, dtype: int64


# 4. PandasAI and SmartDataframe

For this section, we will try out [PandasAI](https://pandas-ai.com/). To get started, create your `PANDASAI_API_KEY` [here](https://pandabi.ai). Optionally, you can also add it to the `.env` as `PANDASAI_API_KEY`.

In [5]:
os.environ["PANDASAI_API_KEY"] = input("Enter your PANDASAI_API_KEY: ")

In [6]:
df = pd.read_csv("data/wildfires_lite.csv")

In [7]:
llm = OpenAI(api_token=os.getenv("OPENAI_API_KEY"))
agent = Agent(df, config={"llm": llm})

In [8]:
questions = [
    "What is this dataset about?",
    "What questions can I ask?",
    "Have wildfires become more or less frequent over time?",
    "What counties are the most and least fire-prone?",
    "Given the size, location and date, can you predict the cause of a fire wildfire?",
]

agent.chat(questions[2])

'The trend of wildfires over time is decreasing.'

In [11]:
# You can instantiate a SmartDataframe with a path to a CSV file
sdf = SmartDataframe("data/wildfires_lite.csv", config={"llm": llm})

response = sdf.chat(questions[2])
print(response)

The trend of wildfires over time is decreasing.


# Plan-ReAct-Self-Refine (PRSR)

In [12]:
llm = ChatOpenAI()

question = "What counties are the most and least fire-prone?"

  warn_deprecated(


In [13]:
#@title Context

dataframe_context = """The following Python data exploration task is on a pandas DataFrame on wildfires. Below are the columns.

Fires: Table including wildfire data for the period of 1992-2015 compiled from US federal, state, and local reporting systems.
FOD_ID = Global unique identifier.
FPA_ID = Unique identifier that contains information necessary to track back to the original record in the source dataset.
SOURCE_SYSTEM_TYPE = Type of source database or system that the record was drawn from (federal, nonfederal, or interagency).
SOURCE_SYSTEM = Name of or other identifier for source database or system that the record was drawn from. See Table 1 in Short (2014), or \Supplements\FPA_FOD_source_list.pdf, for a list of sources and their identifier.
NWCG_REPORTING_AGENCY = Active National Wildlife Coordinating Group (NWCG) Unit Identifier for the agency preparing the fire report (BIA = Bureau of Indian Affairs, BLM = Bureau of Land Management, BOR = Bureau of Reclamation, DOD = Department of Defense, DOE = Department of Energy, FS = Forest Service, FWS = Fish and Wildlife Service, IA = Interagency Organization, NPS = National Park Service, ST/C&L = State, County, or Local Organization, and TRIBE = Tribal Organization).
NWCG_REPORTING_UNIT_ID = Active NWCG Unit Identifier for the unit preparing the fire report.
NWCG_REPORTING_UNIT_NAME = Active NWCG Unit Name for the unit preparing the fire report.
SOURCE_REPORTING_UNIT = Code for the agency unit preparing the fire report, based on code/name in the source dataset.
SOURCE_REPORTING_UNIT_NAME = Name of reporting agency unit preparing the fire report, based on code/name in the source dataset.
LOCAL_FIRE_REPORT_ID = Number or code that uniquely identifies an incident report for a particular reporting unit and a particular calendar year.
LOCAL_INCIDENT_ID = Number or code that uniquely identifies an incident for a particular local fire management organization within a particular calendar year.
FIRE_CODE = Code used within the interagency wildland fire community to track and compile cost information for emergency fire suppression (https://www.firecode.gov/).
FIRE_NAME = Name of the incident, from the fire report (primary) or ICS-209 report (secondary).
ICS_209_INCIDENT_NUMBER = Incident (event) identifier, from the ICS-209 report.
ICS_209_NAME = Name of the incident, from the ICS-209 report.
MTBS_ID = Incident identifier, from the MTBS perimeter dataset.
MTBS_FIRE_NAME = Name of the incident, from the MTBS perimeter dataset.
COMPLEX_NAME = Name of the complex under which the fire was ultimately managed, when discernible.
FIRE_YEAR = Calendar year in which the fire was discovered or confirmed to exist.
DISCOVERY_DATE = Date on which the fire was discovered or confirmed to exist.
DISCOVERY_DOY = Day of year on which the fire was discovered or confirmed to exist.
DISCOVERY_TIME = Time of day that the fire was discovered or confirmed to exist.
STAT_CAUSE_CODE = Code for the (statistical) cause of the fire.
STAT_CAUSE_DESCR = Description of the (statistical) cause of the fire.
CONT_DATE = Date on which the fire was declared contained or otherwise controlled (mm/dd/yyyy where mm=month, dd=day, and yyyy=year).
CONT_DOY = Day of year on which the fire was declared contained or otherwise controlled.
CONT_TIME = Time of day that the fire was declared contained or otherwise controlled (hhmm where hh=hour, mm=minutes).
FIRE_SIZE = Estimate of acres within the final perimeter of the fire.
FIRE_SIZE_CLASS = Code for fire size based on the number of acres within the final fire perimeter expenditures (A=greater than 0 but less than or equal to 0.25 acres, B=0.26-9.9 acres, C=10.0-99.9 acres, D=100-299 acres, E=300 to 999 acres, F=1000 to 4999 acres, and G=5000+ acres).
LATITUDE = Latitude (NAD83) for point location of the fire (decimal degrees).
LONGITUDE = Longitude (NAD83) for point location of the fire (decimal degrees).
OWNER_CODE = Code for primary owner or entity responsible for managing the land at the point of origin of the fire at the time of the incident.
OWNER_DESCR = Name of primary owner or entity responsible for managing the land at the point of origin of the fire at the time of the incident.
STATE = Two-letter alphabetic code for the state in which the fire burned (or originated), based on the nominal designation in the fire report.
COUNTY = County, or equivalent, in which the fire burned (or originated), based on nominal designation in the fire report.
FIPS_CODE = Three-digit code from the Federal Information Process Standards (FIPS) publication 6-4 for representation of counties and equivalent entities.
FIPS_NAME = County name from the FIPS publication 6-4 for representation of counties and equivalent entities.
NWCG_UnitIDActive_20170109: Look-up table containing all NWCG identifiers for agency units that were active (i.e., valid) as of 9 January 2017, when the list was downloaded from https://www.nifc.blm.gov/unit_id/Publish.html and used as the source of values available to populate the following fields in the Fires table: NWCG_REPORTING_AGENCY, NWCG_REPORTING_UNIT_ID, and NWCG_REPORTING_UNIT_NAME.
UnitId = NWCG Unit ID.
GeographicArea = Two-letter code for the geographic area in which the unit is located (NA=National, IN=International, AK=Alaska, CA=California, EA=Eastern Area, GB=Great Basin, NR=Northern Rockies, NW=Northwest, RM=Rocky Mountain, SA=Southern Area, and SW=Southwest).
Gacc = Seven or eight-letter code for the Geographic Area Coordination Center in which the unit is located or primarily affiliated with (CAMBCIFC=Canadian Interagency Forest Fire Centre, USAKCC=Alaska Interagency Coordination Center, USCAONCC=Northern California Area Coordination Center, USCAOSCC=Southern California Coordination Center, USCORMCC=Rocky Mountain Area Coordination Center, USGASAC=Southern Area Coordination Center, USIDNIC=National Interagency Coordination Center, USMTNRC=Northern Rockies Coordination Center, USNMSWC=Southwest Area Coordination Center, USORNWC=Northwest Area Coordination Center, USUTGBC=Western Great Basin Coordination Center, USWIEACC=Eastern Area Coordination Center).
WildlandRole = Role of the unit within the wildland fire community.
UnitType = Type of unit (e.g., federal, state, local).
Department = Department (or state/territory) to which the unit belongs (AK=Alaska, AL=Alabama, AR=Arkansas, AZ=Arizona, CA=California, CO=Colorado, CT=Connecticut, DE=Delaware, DHS=Department of Homeland Security, DOC= Department of Commerce, DOD=Department of Defense, DOE=Department of Energy, DOI= Department of Interior, DOL=Department of Labor, FL=Florida, GA=Georgia, IA=Iowa, IA/GC=Non-Departmental Agencies, ID=Idaho, IL=Illinois, IN=Indiana, KS=Kansas, KY=Kentucky, LA=Louisiana, MA=Massachusetts, MD=Maryland, ME=Maine, MI=Michigan, MN=Minnesota, MO=Missouri, MS=Mississippi, MT=Montana, NC=North Carolina, NE=Nebraska, NG=Non-Government, NH=New Hampshire, NJ=New Jersey, NM=New Mexico, NV=Nevada, NY=New York, OH=Ohio, OK=Oklahoma, OR=Oregon, PA=Pennsylvania, PR=Puerto Rico, RI=Rhode Island, SC=South Carolina, SD=South Dakota, ST/L=State or Local Government, TN=Tennessee, Tribe=Tribe, TX=Texas, USDA=Department of Agriculture, UT=Utah, VA=Virginia, VI=U. S. Virgin Islands, VT=Vermont, WA=Washington, WI=Wisconsin, WV=West Virginia, WY=Wyoming).
Agency = Agency or bureau to which the unit belongs (AG=Air Guard, ANC=Alaska Native Corporation, BIA=Bureau of Indian Affairs, BLM=Bureau of Land Management, BOEM=Bureau of Ocean Energy Management, BOR=Bureau of Reclamation, BSEE=Bureau of Safety and Environmental Enforcement, C&L=County & Local, CDF=California Department of Forestry & Fire Protection, DC=Department of Corrections, DFE=Division of Forest Environment, DFF=Division of Forestry Fire & State Lands, DFL=Division of Forests and Land, DFR=Division of Forest Resources, DL=Department of Lands, DNR=Department of Natural Resources, DNRC=Department of Natural Resources and Conservation, DNRF=Department of Natural Resources Forest Service, DOA=Department of Agriculture, DOC=Department of Conservation, DOE=Department of Energy, DOF=Department of Forestry, DVF=Division of Forestry, DWF=Division of Wildland Fire, EPA=Environmental Protection Agency, FC=Forestry Commission, FEMA=Federal Emergency Management Agency, FFC=Bureau of Forest Fire Control, FFP=Forest Fire Protection, FFS=Forest Fire Service, FR=Forest Rangers, FS=Forest Service, FWS=Fish & Wildlife Service, HQ=Headquarters, JC=Job Corps, NBC=National Business Center, NG=National Guard, NNSA=National Nuclear Security Administration, NPS=National Park Service, NWS=National Weather Service, OES=Office of Emergency Services, PRI=Private, SF=State Forestry, SFS=State Forest Service, SP=State Parks, TNC=The Nature Conservancy, USA=United States Army, USACE=United States Army Corps of Engineers, USAF=United States Air Force, USGS=United States Geological Survey, USN=United States Navy).
Parent = Agency subgroup to which the unit belongs (A concatenation of State and Unit from this report - https://www.nifc.blm.gov/unit_id/publish/UnitIdReport.rtf).
Country = Country in which the unit is located (e.g. US = United States).
State = Two-letter code for the state in which the unit is located (or primarily affiliated).
Code = Unit code (follows state code to create UnitId).
Name = Unit name.

Below are the datatypes for each column.

RangeIndex: 1880465 entries, 0 to 1880464
Data columns (total 39 columns):
 #   Column                      Dtype
---  ------                      -----
 0   OBJECTID                    int64
 1   FOD_ID                      int64
 2   FPA_ID                      object
 3   SOURCE_SYSTEM_TYPE          object
 4   SOURCE_SYSTEM               object
 5   NWCG_REPORTING_AGENCY       object
 6   NWCG_REPORTING_UNIT_ID      object
 7   NWCG_REPORTING_UNIT_NAME    object
 8   SOURCE_REPORTING_UNIT       object
 9   SOURCE_REPORTING_UNIT_NAME  object
 10  LOCAL_FIRE_REPORT_ID        object
 11  LOCAL_INCIDENT_ID           object
 12  FIRE_CODE                   object
 13  FIRE_NAME                   object
 14  ICS_209_INCIDENT_NUMBER     object
 15  ICS_209_NAME                object
 16  MTBS_ID                     object
 17  MTBS_FIRE_NAME              object
 18  COMPLEX_NAME                object
 19  FIRE_YEAR                   int64
 20  DISCOVERY_DATE              datetime64[ns]
 21  DISCOVERY_DOY               int64
 22  DISCOVERY_TIME              float64
 23  STAT_CAUSE_CODE             float64
 24  STAT_CAUSE_DESCR            object
 25  CONT_DATE                   float64
 26  CONT_DOY                    float64
 27  CONT_TIME                   float64
 28  FIRE_SIZE                   float64
 29  FIRE_SIZE_CLASS             object
 30  LATITUDE                    float64
 31  LONGITUDE                   float64
 32  OWNER_CODE                  float64
 33  OWNER_DESCR                 object
 34  STATE                       object
 35  COUNTY                      object
 36  FIPS_CODE                   float64
 37  FIPS_NAME                   object
 38  Shape                       object
dtypes: datetime64[ns](1), float64(10), int64(4), object(24)"""

In [14]:
#@title ReAct Prompt

react_prompt = """Solve a data analysis sub-task on a wildfires pd.DataFrame with interleaving Thought and Action steps.
Thoughts reason about the next appropriate step and Action is the code implementation based on the Thoughts separated by \\n.

## Context:
{context}

## Rules:
- Complete only the current step of the plan shown below.
- Do not add ```python ``` when generating code for Action.
- You have a maximum of {max_steps} steps.
- Action steps only include code.

## Plan:
{plan}

## Code:
```python

{code}
```

The current step is: {current_step}. Let's think step by step.

Question: {question}{memory}
"""

In [16]:
def generate_react(
  question: str,
  current_step: int,
  plan: str,
  memory: str = "",
  code: str = "",
  max_steps: int = 1,
) -> str:
  actions_taken = 0
  while actions_taken < max_steps:
    # Think.
    memory += '\nThought: '
    thought_prompt = react_prompt.format(
        max_steps=max_steps,
        context=dataframe_context,
        question=question,
        memory=memory,
        code=code,
        plan=plan,
        current_step=current_step
    )
    # print("THOUGHT PROMPT<===============================================>")
    # print(thought_prompt)
    # print("THOUGHT PROMPT<===============================================>")
    thought = llm(
        [
            HumanMessage(
                content=thought_prompt,
            )
        ]
    ).content.split("Action:")[0]
    # print("THOUGHT OUT<===============================================>")
    # print(thought)
    # print("THOUGHT OUT<===============================================>")
    memory += thought

    # Act.
    memory += '\nAction: '
    action_prompt = react_prompt.format(
        max_steps=max_steps,
        context=dataframe_context,
        question=question,
        memory=memory,
        code=code,
        plan=plan,
        current_step=current_step
    )
    # print("ACTION PROMPT<===============================================>")
    # print(action_prompt)
    # print("ACTION PROMPT<===============================================>")
    action = llm(
        [
            HumanMessage(
                content=action_prompt,
            )
        ]
    ).content
    action = "\n" + action.lstrip("```python").rstrip("```").lstrip("").strip() + "\n"
    # print("ACTION OUT<===============================================>")
    # print(action)
    # print("ACTION OUT<===============================================>")
    code += action
    memory += action

    actions_taken += 1

  return code

In [17]:
out = generate_react(question, current_step=0, plan="")  # Testing if it works.
print(out)

  warn_deprecated(



county_fire_freq = df['COUNTY'].value_counts()



In [18]:
#@title Refinement Prompt

refinement_prompt = """You have attempted to answer following data analysis question on a wildfire pd.DataFrame.

## Context:
{context}

You will be provided the plan and code.

## Plan:
{plan}

## Code:
{code}

There is an error in the code above because of lack of understanding of the question. What is the error?
To find the error, go through semantically complete blocks of the code, and check if everything looks good.
If there is no error, output only '<CORRECT>'"""


feedback_prompt = """You will be given a plan and code that answers a question. The code is erroneous and thus the suggested refinement is provided.

## Context:
{context}

Question: {question}

## Plan:
{plan}

## Code:
{code}

Refinement: {refinement}

Update the code with the refinement suggestions and output this new updated code."""

In [19]:
def generate_refinement(plan: str, code: str) -> str:
    prompt = refinement_prompt.format(
        context=dataframe_context,
        plan=plan,
        code=code
    )
    refinement = llm([HumanMessage(content=prompt)]).content
    return refinement

def refine_code(question: str, plan: str, code: str, refinement: str) -> str:
    prompt = feedback_prompt.format(
        context=dataframe_context,
        question=question,
        plan=plan,
        code=code,
        refinement=refinement
    )
    code = llm([HumanMessage(content=prompt)]).content
    return code

In [20]:
#@title Planner Prompt

planner_instruction_prompt = """As the Planner, your job is to decompose the user's data analytics question into a sequential list of substeps.
These steps should only be in natural language and understood by a Data Analyst/Scientist who can translate this into Python code.

{context}

The first step should be to import relevant libraries and read the csv file. The .csv file is located at {csv_file_path}.
The last step should be to return the answer(s).

Question: {question}
"""

In [22]:
def generate_plan(
  question: str,
  csv_file_path: str
) -> List[str]:
    prompt = planner_instruction_prompt.format(context=dataframe_context, question=question, csv_file_path=csv_file_path)
    plan = llm([HumanMessage(content=prompt)]).content
    plan = plan.split("\n")

    return plan

In [23]:
plan = generate_plan(question, "wildfires_lite.csv")
plan

['1. Import the pandas library.',
 '2. Read the csv file "wildfires_lite.csv" into a pandas DataFrame.',
 '3. Group the data by the "COUNTY" column.',
 '4. Calculate the total number of fires for each county.',
 '5. Sort the counties based on the total number of fires in descending order to find the most fire-prone counties.',
 '6. Identify the county(s) with the highest number of fires.',
 '7. Sort the counties based on the total number of fires in ascending order to find the least fire-prone counties.',
 '8. Identify the county(s) with the lowest number of fires.',
 '9. Return the county(s) identified as the most and least fire-prone.']

In [24]:
code = ""
for current_step, step in enumerate(plan, 1):
  # Execute ReAct on the current_step.
  code = generate_react(step, plan="\n".join(plan), current_step=current_step, code=code)

# Generate refinement at the end of the plan.
refinement = generate_refinement(plan, code)

# Consolidate the refinement/feedback and code into a final output.
code = refine_code(question, plan, code, refinement)

In [25]:
print(code)

```python
import pandas as pd

df = pd.read_csv("wildfires_lite.csv")

county_grouped = df.groupby("COUNTY")

county_fires_count = county_grouped.size()

county_fires_count_sorted = county_fires_count.sort_values(ascending=False)

most_fire_prone_counties = county_fires_count_sorted.head(1)

least_fire_prone_counties = county_fires_count_sorted.sort_values(ascending=True)

most_fire_prone_counties_names = most_fire_prone_counties.index.tolist()
least_fire_prone_counties_names = least_fire_prone_counties.index.tolist()

most_fire_prone_counties_names, least_fire_prone_counties_names
```


# TaskWeaver

In [32]:
# Copy the wildfires_lite.csv to the `sample_data` folder.
df = pd.read_csv("data/wildfires_lite.csv")
df.to_csv("TaskWeaver/project/sample_data/wildfires_lite.csv", index=False)

In [105]:
import sys
sys.path.append("TaskWeaver")

In [158]:
from taskweaver.app.app import TaskWeaverApp

# This is the folder that contains the taskweaver_config.json file and not the repo root. Defaults to "./project/"
app_dir = "./TaskWeaver/project/"
app = TaskWeaverApp(app_dir=app_dir)
session = app.get_session()

user_query = "calculate mean value of ./TaskWeaver/project/sample_data/demo_data.csv"
response_round = session.send_message(user_query)
print(response_round.to_dict())

{'id': 'round-20240418-055303-75c97fc5', 'user_query': 'calculate mean value of /project/sample_data/demo_data.csv', 'state': 'finished', 'post_list': [{'id': 'post-20240418-055303-8f49d791', 'message': 'calculate mean value of /project/sample_data/demo_data.csv', 'send_from': 'User', 'send_to': 'Planner', 'attachment_list': []}, {'id': 'post-20240418-055303-fef7b852', 'message': 'Please load the data file [/project/sample_data/demo_data.csv](/project/sample_data/demo_data.csv) and calculate the mean value of the loaded data', 'send_from': 'Planner', 'send_to': 'CodeInterpreter', 'attachment_list': [{'id': 'atta-20240418-055308-ba53250e', 'type': 'init_plan', 'content': '1. Load /project/sample_data/demo_data.csv file\n2. Calculate the mean value of the loaded data <sequentially depends on 1>\n3. Report the calculated mean value to the user <interactively depends on 2>', 'extra': None}, {'id': 'atta-20240418-055313-04bd8367', 'type': 'plan', 'content': '1. Load /project/sample_data/dem