In [8]:
# !pip install dspy langchain langchain_community
# !pip install rank_bm25

In [3]:
s = "… extractive families"
s = s.replace("… ","")
s

'extractive families'

In [2]:
from src.database import build_dcids_database, load_database
build_dcids_database()
dcid_collection = load_database()

BadRequestError: Error code: 400 - {'error': {'message': "'$.input' is invalid. Please check the API reference: https://platform.openai.com/docs/api-reference.", 'type': 'invalid_request_error', 'param': None, 'code': None}}

In [1]:
from src.get_place_dcids import place_dcid

In [2]:
place_dcid(["India"])

'geoId/1836003'

In [3]:
from abc import ABC, abstractmethod
from typing import Any

class DataCommonsAgent(ABC):
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args,**kwargs)
    @abstractmethod
    def forward(self,question):
        pass

In [4]:
from langchain.schema import Document
elems = dcid_collection.get()
langchain_docs = []
for doc,metadata in zip(elems['documents'],elems['metadatas']):
    
    langchain_docs.append(Document(page_content=doc,metadata=metadata))

In [5]:
import dspy
from langchain_community.retrievers import BM25Retriever
from dotenv import load_dotenv,find_dotenv
import concurrent.futures
import datacommons_pandas as dc
from typing import Annotated, List
import chromadb.utils.embedding_functions as embedding_functions
import chromadb
import os

load_dotenv(find_dotenv(),override=True)
llm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=llm)
"""Returns the places that the question is talking about separated by semicolon (;) and also only the noun keywords relevant to the question in a list
    Make sure that you are only outputing the noun keywords and not other things"""
class PlaceKeywordSignature(dspy.Signature):
    """Returns the places that the question is talking about separated by semicolon (;)"""
    question = dspy.InputField(prefix="Question: ",desc="Question asked by the user")
    places = dspy.OutputField(prefix="Places: ",desc="places like countries, states, towns, etc mentioned in the question separated by semicolon (;)")
    # keywords = dspy.OutputField(prefix="Keywords: ",desc="noun keywords relevant to the question in a list. DON'T include the place names and be precise")

class SelectDCIDSignature(dspy.Signature):
    """Based on the dcid and their descriptions, select the dcid(s) that are most relevant to the question. Return the relevant dcids separated by semicolon (;)
    Don't output anything else, just output the relevant dcid(s). You have to output only from the given dcids, don't output any other dcids"""
    dcids_list = dspy.InputField(prefix="DCID and Description List: ",desc="DCIDs and its corresponding description")
    relevant_dcids = dspy.OutputField(prefix="Relevant DCIDs: ",desc="relevant dcids only separated by semicolon (;)")

emb_fn = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.environ["OPENAI_API_KEY"], model_name="text-embedding-3-small"
)

class DataCommonsDSPy(dspy.Module):
    def __init__(self,dcid_collection:chromadb.Collection):
        super().__init__()
        self.datacommons_collection = dcid_collection
        self.place_keyword_llm = dspy.ChainOfThought(PlaceKeywordSignature)
        self.relevant_dcid_llm = dspy.ChainOfThought(SelectDCIDSignature)
        self.bm25_retriever = BM25Retriever.from_documents(
                langchain_docs, k=20, preprocess_func=(lambda x: x.lower())
            )
    def __call__(self,question:str, **kwargs):
        return self.forward(question, **kwargs)
    
    def _where_clause_dcids_helper_func(self,dcids_list:List[str]):
        assert len(dcids_list)>1, "Check the BM25 retriever, the number of returned documents should be more than 1 from the sparse search"
        dcid_where_clause = {"$or": [{"dcid": {"$eq": t}} for t in dcids_list]}
        return dcid_where_clause

    def forward(self,question:Annotated[str,"Question that will be answered by the DataCommons Agent"]):
        question_emb = emb_fn([question])[0]
        llm_answer = self.place_keyword_llm(question=question)
        places = llm_answer.places.split(";")
        places = [pl.strip() for pl in places]
        print()
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            results = executor.map(place_dcid,places) 
        place_dcids = []
        for res in results:
            place_dcids.append(res)
        print(place_dcids)
        # Hybrid search (BM25 followed by dense retrieval)
        # bm25_docs = self.bm25_retriever.invoke(question.lower())
        # bm25_dcids = [doc.metadata['dcid'] for doc in bm25_docs]
        # dcid_where_clause = self._where_clause_dcids_helper_func(bm25_dcids)
        dense_retrieval_docs = self.datacommons_collection.query(
            query_embeddings=question_emb,
            # where=dcid_where_clause,
            n_results=5
        )
        select_dcid_str:str = ""
        for dcid_docs, dcid_metadata in zip(
            dense_retrieval_docs["documents"][0], dense_retrieval_docs["metadatas"][0]
        ):
            select_dcid_str+=f"{dcid_metadata['dcid']}: {dcid_docs}\n\n"
        
        print(select_dcid_str)
        relevant_dcid_result = self.relevant_dcid_llm(dcids_list=select_dcid_str)
        relevant_dcid_list = relevant_dcid_result.relevant_dcids.split(";")
        relevant_dcid_list = [rdl.strip() for rdl in relevant_dcid_list]
        print(relevant_dcid_list)
        result_df = dc.build_multivariate_dataframe(place_dcids,relevant_dcid_list)
        return result_df

dc_chroma = DataCommonsDSPy(dcid_collection)
# dc_chroma("What is the number of patients recovered in COVID-19 from United States and Qatar?")

In [31]:
dc_chroma("What is the air quality index in USA ?")


['geoId/1836003']
AirQualityIndex_AirPollutant: Air Quality Index

AirQualityIndex_AirPollutant_PM2.5: Air Quality Index: PM 2.5

AirQualityIndex_AirPollutant_Ozone: Air Quality Index: Ozone

AirQualityIndex_AirPollutant_NO2: Air Quality Index: Nitrogen Dioxide

AirQualityIndex_AirPollutant_SO2: Air Quality Index: Sulfur Dioxide


['AirQualityIndex_AirPollutant_PM2.5', 'AirQualityIndex_AirPollutant_Ozone', 'AirQualityIndex_AirPollutant_NO2', 'AirQualityIndex_AirPollutant_SO2']


ValueError: No data for any of specified Places and StatisticalVariables.

In [27]:
all_dcids_agent(question="What is the number of patients recovered in COVID-19 from United States and Qatar?")

Prediction(
    rationale='produce the keywords. We need to identify the number of patients recovered in COVID-19 from the United States and Qatar.',
    places='United States; Qatar',
    keywords='number, patients, recovered, COVID-19'
)
['country/USA', 'country/QAT']


Unnamed: 0_level_0,InterestRate_TreasurySecurity_1Month_ConstantMaturity,Count_MedicalConditionIncident_COVID_19_PatientInICU
place,Unnamed: 1_level_1,Unnamed: 2_level_1
country/USA,5.51,1602


In [28]:
dcid = ['AmountInterestRepayment_Debt_OfficialCreditor_Publicly']
# dcid = ['InterestRate_TreasurySecurity_1Month_ConstantMaturity', 'Count_MedicalConditionIncident_COVID_19_PatientInICU']
place_dcids = ['country/USA']

dc.build_multivariate_dataframe(place_dcids,dcid)

ValueError: No data for any of specified Places and StatisticalVariables.

In [25]:
dc_chroma(question="What is the Nonfinancial Commercial Paper Interest Rate in USA?")


['country/USA']
AmountInterestRepayment_Debt_OfficialCreditor_PubliclyGuaranteed_Bilateral_LongTermExternalDebt_LenderCountryUSA: Publicly guaranteed bilateral external debt pricipal repayment to offical creditor, United States.

AmountInterestRepayment_Debt_OfficialCreditor_Concessional_PubliclyGuaranteed_Bilateral_LongTermExternalDebt_LenderCountryUSA: Publicly guaranteed bilateral long term external debt pricipal repayment to offical creditor, concessional United States.

AmountInterestRepayment_Debt_LongTermExternalDebt_LenderCountryUSA: Long term external debt pricipal repayment to United States.

AmountInterestRepayment_Debt_OfficialCreditor_PubliclyGuaranteed_Bilateral_LongTermExternalDebt_LenderCountryPRI: Publicly guaranteed bilateral external debt pricipal repayment to offical creditor, Puerto Rico.

AmountInterestRepayment_Debt_OfficialCreditor_PubliclyGuaranteed_Bilateral_LongTermExternalDebt_LenderCountryPAN: Publicly guaranteed bilateral external debt pricipal repayment 

ValueError: No data for any of specified Places and StatisticalVariables.

In [29]:
dc_chroma(question="What is the annual consumption of Lignite coal in India?")


['geoId/1836003']
IsContaminated_Manganese_LandfillGas: Whether LandfillGas is contaminated with Manganese.

IsContaminated_Chromium_LandfillGas: Whether LandfillGas is contaminated with Chromium.

IsContaminated_Arsenic_LandfillGas: Whether LandfillGas is contaminated with Arsenic.

IsContaminated_Iron_LandfillGas: Whether LandfillGas is contaminated with Iron.

IsContaminated_Lithium_Soil: Whether Soil is contaminated with Lithium.


['IsContaminated_Manganese_LandfillGas', 'IsContaminated_Chromium_LandfillGas', 'IsContaminated_Arsenic_LandfillGas', 'IsContaminated_Iron_LandfillGas']


ValueError: No data for any of specified Places and StatisticalVariables.