In [1]:
import os
import json
import dotenv
import modal
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain import PromptTemplate, LLMChain

# Directory holding the persisted Chroma stores, and the dotenv file. Each
# constant is used both as the local mount source and as the remote mount
# target below, and is the same path load_dotenv / load_vectordb read, so
# changing one constant keeps all of those in sync.
PERSIST_DIR = "embeddings"
DOT_ENV_PATH = ".env"

dotenv.load_dotenv(DOT_ENV_PATH)
# NOTE(review): not referenced elsewhere in this file as visible; presumably
# the OpenAI wrappers read the key from the environment themselves. Kept so
# any external importer of this module still finds it.
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')

image = modal.Image.debian_slim().pip_install_from_requirements("requirements.txt")

# Ship the dotenv file and the embeddings directory into the container at the
# same relative paths the code uses to read them.
mounts = [
    modal.Mount.from_local_file(DOT_ENV_PATH, remote_path=DOT_ENV_PATH),
    modal.Mount.from_local_dir(PERSIST_DIR, remote_path=PERSIST_DIR),
]

stub = modal.Stub(name="Synapse", mounts=mounts, image=image)

def load_vectordb(db_name):
    """Open the persisted Chroma store *db_name* and return a retriever over it.

    Args:
        db_name: name of a store located under ``PERSIST_DIR``. ``run_query``
            passes the value straight from the web request, so it is treated
            as untrusted and must resolve to a path strictly inside
            ``PERSIST_DIR``.

    Returns:
        The retriever produced by ``Chroma.as_retriever()``.

    Raises:
        ValueError: if *db_name* is empty or resolves outside ``PERSIST_DIR``
            (for example ``"../x"`` or an absolute path).
    """
    base = os.path.realpath(PERSIST_DIR)
    target = os.path.realpath(os.path.join(base, db_name))
    # os.path.join drops `base` entirely when db_name is absolute, and ".."
    # segments can climb out of it; only accept paths that land strictly
    # below the persistence directory.
    if target == base or os.path.commonpath([base, target]) != base:
        raise ValueError(f"invalid database name: {db_name!r}")

    embedding = OpenAIEmbeddings(disallowed_special=())
    vectordb = Chroma(
        persist_directory=os.path.join(PERSIST_DIR, db_name),
        embedding_function=embedding,
    )
    return vectordb.as_retriever()

def create_prompt_template():
    """Return the PromptTemplate asking the model for a quotation patchwork.

    The template exposes two placeholders, ``query`` and ``corpus``, which
    must both be supplied when the resulting chain is run.
    """
    # The leading newline and the 4-space indentation inside this literal are
    # part of the text sent to the model and are intentionally left as-is.
    prompt_text = """
    You are a creativity engine. You will show a reduced representation of a corpus to draw out connections with reference to the original query of your pupil.

    For example, the following is the query:
    "{query}"

    The following is the corpus:
    "{corpus}"

    Use extreme patchwriting and mosaic writing as your style, maximally using exact quotations. Use quotes ("") and ellipses (...) to delineate between quotations of the corpus.
    Use the quotations as your canvas instead of using your own words. Do this to help the pupil have creative thoughts based on a reduced representation of the original text with reference to the given query.
    Only give the resultant patchwork and no extra explanations. Let your work speak for itself.
    """
    placeholders = ['query', 'corpus']
    return PromptTemplate(input_variables=placeholders, template=prompt_text)

def create_llm_chain(prompt, model_name="gpt-3.5-turbo"):
    """Wrap *prompt* in an LLMChain backed by an OpenAI chat model.

    Args:
        prompt: PromptTemplate whose variables are supplied when the chain runs.
        model_name: OpenAI chat model to use; defaults to the model this code
            has always requested.

    Returns:
        An ``LLMChain`` combining *prompt* with the chat model.
    """
    # gpt-3.5-turbo is a chat model. LangChain's completion-style `OpenAI`
    # wrapper (constructed with `model=`) targets the completions endpoint,
    # which rejects chat models; `ChatOpenAI` uses chat completions instead.
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(model_name=model_name)
    return LLMChain(prompt=prompt, llm=llm)

def process_documents(docs, llm_chain, query):
    """Replace each document's text with the LLM's patchwork for *query*.

    Args:
        docs: list of dicts, each with a ``'page_content'`` key; modified in
            place.
        llm_chain: object whose ``run`` accepts a dict with ``'query'`` and
            ``'corpus'`` keys and returns a string.
        query: the user's query, passed unchanged to every call.

    Returns:
        None. Results are written back into ``docs[i]['page_content']``.
    """
    # Built once: newlines become spaces, double and single quotes are dropped.
    cleanup = str.maketrans({"\n": " ", '"': None, "'": None})
    for doc in docs:
        chain_inputs = {'query': query, 'corpus': doc['page_content']}
        doc['page_content'] = llm_chain.run(chain_inputs).translate(cleanup)

@stub.function(cpu=2, memory=2048, container_idle_timeout=300, keep_warm=1)
@modal.web_endpoint(method="POST")
def run_query(query: str, db_name: str):
    """Web endpoint: retrieve documents for *query* from *db_name* and rewrite them.

    Each retrieved document is turned into a plain dict via its JSON form.
    Its ``page_content`` is then replaced by the LLM's patchwork for *query*.
    The resulting list of dicts is the response body.
    """
    retriever = load_vectordb(db_name)
    # Round-trip through JSON so the response holds only plain values.
    serialised = [
        json.loads(match.json())
        for match in retriever.get_relevant_documents(query=query)
    ]
    chain = create_llm_chain(create_prompt_template())
    process_documents(serialised, chain, query)  # rewrites page_content in place
    return serialised