In [1]:
from src.get_place_dcids import place_dcid

In [9]:
import os
from langchain.schema import Document
import ast

# Build one LangChain Document per node listed in the files under src/STATS.
# Each file holds a Python-literal list of dicts with "node_name",
# "node_dcid" and "node_link" keys; the file name (underscores turned into
# dots) is recorded as the document's "data_source".
docs = []
for stat_files in os.listdir("src/STATS"):
    stat_file_name = stat_files.replace("_", ".")
    # Explicit encoding so decoding does not depend on the platform default.
    with open(os.path.join("src/STATS", stat_files), "r", encoding="utf-8") as f:
        # literal_eval only accepts Python literals, so it cannot execute code.
        content = ast.literal_eval(f.read())
    for stat in content:
        docs.append(
            Document(
                page_content=stat["node_name"],
                metadata={
                    "dcid": stat["node_dcid"],
                    "link": stat["node_link"],
                    "data_source": stat_file_name,
                },
            )
        )

In [None]:
import dspy
from langchain_community.retrievers import BM25Retriever
from dotenv import load_dotenv,find_dotenv
import concurrent.futures
# Load variables from the nearest .env file into the process environment;
# override=True lets .env values replace variables that are already set.
# NOTE(review): presumably this supplies the OpenAI API key used by
# dspy.OpenAI below — confirm.
load_dotenv(find_dotenv(),override=True)

# DSPy signature for extracting places and keywords from a user question.
# NOTE: DSPy sends the class docstring and the field prefixes/descriptions to
# the LLM as the prompt, so they are runtime behavior, not just documentation.
# Both outputs use ';' as the delimiter so downstream code can split them the
# same way.
class PlaceKeywordSignature(dspy.Signature):
    """Returns the places that the question is talking about separated by semicolon (;) and also the keywords relevant to the question separated by semicolon (;)"""
    # Input: the raw user question.
    question = dspy.InputField(prefix="Question: ",desc="Question asked by the user")
    # Output: ';'-separated place names.
    places = dspy.OutputField(prefix="Places: ",desc="places like countries, states, towns, etc mentioned in the question separated by semicolon (;)")
    # Output: ';'-separated search keywords.
    keywords = dspy.OutputField(prefix="Keywords: ",desc="keywords relevant to the question separated by semicolon (;)")

# Create the OpenAI-backed language model and register it as DSPy's global
# default LM, which predictors such as dspy.ChainOfThought below use when no
# LM is passed explicitly.
llm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=llm)

class DataCommonsDSPy(dspy.Module):
    """Map a natural-language question to candidate Data Commons DCIDs.

    An LLM (``PlaceKeywordSignature``) extracts the places and keywords
    mentioned in the question. Places are resolved with ``place_dcid``. The
    question and every keyword are then searched in a BM25 index built over
    the module-level ``docs``. The deduplicated DCIDs of the retrieved
    documents are returned.
    """

    def __init__(self):
        # dspy.Module.__init__ sets up internal state the framework relies on.
        super().__init__()
        self.bm25_retriever = BM25Retriever.from_documents(
            docs,
            k=10,
            # rank_bm25 expects each text as a list of tokens. Returning a bare
            # string would make every character its own token. The same
            # function is applied to queries, so they are lowercased here too.
            preprocess_func=lambda x: x.lower().split(),
        )
        self.place_keyword_llm = dspy.ChainOfThought(PlaceKeywordSignature)

    def __call__(self, question, **kwargs):
        """Run :meth:`forward` on ``question``."""
        return self.forward(question, **kwargs)

    @staticmethod
    def _split_places(raw):
        """Split a ';'-separated LLM answer into stripped, non-empty names."""
        if not raw:
            return []
        return [part.strip() for part in raw.split(";") if part.strip()]

    @staticmethod
    def _parse_keywords(raw):
        """Turn the LLM's keyword answer into a list of non-empty strings.

        Accepts an actual list, a Python-style list literal such as
        ``"['a', 'b']"``, or a ';'- (falling back to ',') separated string.
        """
        import ast

        if not raw:
            return []
        if isinstance(raw, str):
            text = raw.strip()
            parsed = None
            if text.startswith("["):
                # literal_eval cannot execute code, so it is safe on LLM output.
                try:
                    parsed = ast.literal_eval(text)
                except (ValueError, SyntaxError, TypeError):
                    parsed = None
            if isinstance(parsed, (list, tuple)):
                raw = parsed
            else:
                sep = ";" if ";" in text else ","
                raw = text.strip("[]").split(sep)
        cleaned = (str(item).strip().strip("'\"").strip() for item in raw)
        return [item for item in cleaned if item]

    def forward(self,question:str):
        """Return the unique DCIDs of BM25 matches for ``question``.

        Args:
            question: The user's natural-language question.

        Returns:
            A list of DCID strings, without duplicates, in first-retrieved order.
        """
        # DSPy predictors must be called with keyword arguments.
        llm_answer = self.place_keyword_llm(question=question)
        places = self._split_places(llm_answer.places)
        keywords = self._parse_keywords(llm_answer.keywords)

        # NOTE(review): place_dcid presumably does a remote lookup per place,
        # hence the thread pool — confirm.
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            place_dcids = list(executor.map(place_dcid, places))
        # NOTE(review): place_dcids is computed but, as before, not part of the
        # return value — confirm whether callers need it.

        bm25_docs = list(self.bm25_retriever.invoke(question))
        for key in keywords:
            # extend, not append: each invoke returns a list of Documents.
            bm25_docs.extend(self.bm25_retriever.invoke(key))

        # Deduplicate while keeping the first-seen order.
        return list(dict.fromkeys(doc.metadata["dcid"] for doc in bm25_docs))