In [8]:
# !pip install dspy langchain langchain_community
# !pip install rank_bm25

In [1]:
from src.get_place_dcids import place_dcid

In [2]:
import os
from langchain.schema import Document
import ast

docs = []
for stat_files in os.listdir("src/STATS"):
    stat_file_name = ".".join(stat_files.split("_"))
    with open(os.path.join("src/STATS",stat_files), "r") as f:
        content = f.read()
    content = ast.literal_eval(content)
    for stat in content:
        docs.append(Document(page_content=stat['node_name'],metadata={'dcid': stat['node_dcid'],'link': stat['node_link'],'data_source':stat_file_name}))

In [20]:
import dspy
from langchain_community.retrievers import BM25Retriever
from dotenv import load_dotenv,find_dotenv
import concurrent.futures
import datacommons_pandas as dc

load_dotenv(find_dotenv(),override=True)

class PlaceKeywordSignature(dspy.Signature):
    """Returns the places that the question is talking about separated by semicolon (;) and also only the noun keywords relevant to the question in a list
    Make sure that you are only outputing the noun keywords and not other things"""
    question = dspy.InputField(prefix="Question: ",desc="Question asked by the user")
    places = dspy.OutputField(prefix="Places: ",desc="places like countries, states, towns, etc mentioned in the question separated by semicolon (;)")
    keywords = dspy.OutputField(prefix="Keywords: ",desc="noun keywords relevant to the question in a list. DON'T include the place names and be precise")

llm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=llm)

class DataCommonsDSPy(dspy.Module):
    def __init__(self):
        self.bm25_retriever = BM25Retriever.from_documents(
                docs, k=5, preprocess_func=(lambda x: x.lower())
            )
        self.place_keyword_llm = dspy.ChainOfThought(PlaceKeywordSignature)

    def __call__(self,question,**kwargs):
        return self.forward(question,**kwargs)

    def forward(self,question:str):
        llm_answer = self.place_keyword_llm(question=question)
        print(llm_answer)
        places = llm_answer.places.split(";")
        keywords = llm_answer.keywords.split(",")

        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            results = executor.map(place_dcid,places) 
        place_dcids = []
        for res in results:
            place_dcids.append(res)
        print(place_dcids)
        bm25_docs = self.bm25_retriever.invoke(question.lower())
        for key in keywords:
            curr_keyword_docs = self.bm25_retriever.invoke(key.lower())
            bm25_docs.extend(curr_keyword_docs)
        stat_dcids = []
        print(bm25_docs)
        for doc in bm25_docs:
            curr_dcid = doc.metadata['dcid']
            if curr_dcid not in stat_dcids:
                stat_dcids.append(curr_dcid)
        print(stat_dcids)
        result_df = dc.build_multivariate_dataframe(place_dcids,stat_dcids)
        return result_df

In [21]:
all_dcids_agent = DataCommonsDSPy()

In [22]:
all_dcids_agent(question="What is the number of patients recovered in COVID-19 from United States and India")

Prediction(
    rationale='produce the keywords. We need to identify the specific information requested, which is the number of patients recovered from COVID-19 in the United States and India.',
    places='United States; India',
    keywords='number, patients, recovered, COVID-19'
)
['country/USA', 'geoId/1836003']
[Document(page_content='COVID-19 Patients in the Intensive Care Unit', metadata={'dcid': 'Count_MedicalConditionIncident_COVID_19_PatientInICU', 'link': 'https://datacommons.org/browser/Count_MedicalConditionIncident_COVID_19_PatientInICU', 'data_source': 'Our World in Data.json'}), Document(page_content='COVID-19 Patients in the Intensive Care Unit', metadata={'dcid': 'Count_MedicalConditionIncident_COVID_19_PatientInICU', 'link': 'https://datacommons.org/browser/Count_MedicalConditionIncident_COVID_19_PatientInICU', 'data_source': 'Our World in Data.json'}), Document(page_content='COVID-19 Patients in the Intensive Care Unit', metadata={'dcid': 'Count_MedicalConditionInci

Unnamed: 0_level_0,Annual_Amount_Emissions_EPAFuelCombustionOther_NonBiogenicEmissionSource_CarbonMonoxide,Amount_EconomicActivity_GrossDomesticProduction_NAICSAdministrativeSupportWasteManagementRemediationServices_RealValue,CumulativeCount_MedicalConditionIncident_COVID_19_ConfirmedOrProbableCase,CumulativeCount_MedicalConditionIncident_COVID_19_PatientDeceased,Count_MedicalConditionIncident_COVID_19_PatientInICU,Annual_Amount_Emissions_FuelCombustionIndustrial_NonBiogenicEmissionSource_CarbonMonoxide,Mean_BarometricPressure
place,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
country/USA,2666394.0,619576000000.0,103910034.0,1188195.0,1602.0,830217.1723,
geoId/1836003,,,,,,,1015.75


In [7]:
dcid = ['Count_MedicalConditionIncident_COVID_19_PatientInICU']
place_dcids = ['country/USA', 'geoId/1836003']

dc.build_multivariate_dataframe(place_dcids,dcid)

Unnamed: 0_level_0,Count_MedicalConditionIncident_COVID_19_PatientInICU
place,Unnamed: 1_level_1
country/USA,1602
