In [0]:
!pip install -qq langchain
!pip install -qq langchain_core
!pip install -qq langchain_community

In [0]:
!pip install -qq langchain_openai
!pip install -qq faiss-cpu

In [0]:
import os

# Copy the OpenAI key out of the Databricks secret scope into the process environment.
# The OpenAI/OpenAIEmbeddings objects created later are built without an explicit key,
# so they rely on this variable; the key itself never appears in the notebook source.
os.environ.update(
    OPENAI_API_KEY=dbutils.secrets.get(scope="sourav_secret_scope", key="OPENAI_API_KEY")
)

In [0]:
import os
import tempfile

from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
# from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

import mlflow

In [0]:
# data_path = "/".join(dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get().split("/")[:-1])+"/tests"

In [0]:
# Build a FAISS-backed RetrievalQA chain over a sample text file and log it to MLflow.
# The FAISS index lives in a temporary directory that only has to exist until
# `log_model` has consumed it, so all logging happens inside the `with` block.
DOCS_PATH = "./tests/state_of_the_union.txt"
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 0


def load_retriever(persist_directory):
    """Rebuild the FAISS retriever from a saved index directory.

    Passed to MLflow as ``loader_fn`` together with ``persist_dir`` so the
    retriever can be reconstructed when the logged chain is loaded.

    Args:
        persist_directory: Path of a directory written by ``FAISS.save_local``.

    Returns:
        A retriever over the reloaded FAISS vector store.
    """
    embeddings = OpenAIEmbeddings()
    # FAISS.load_local refuses to unpickle its docstore unless this flag is set;
    # only point it at indexes this notebook produced itself.
    vectorstore = FAISS.load_local(persist_directory, embeddings, allow_dangerous_deserialization=True)
    return vectorstore.as_retriever()


with tempfile.TemporaryDirectory() as temp_dir:
    persist_dir = os.path.join(temp_dir, "faiss_index")

    # Chunk the source document, embed it, and persist the FAISS index locally.
    loader = TextLoader(DOCS_PATH)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(persist_dir)

    # Create the RetrievalQA chain. `allow_dangerous_deserialization` is a
    # FAISS.load_local option, not an LLM parameter, so the LLM gets no extra kwargs.
    retrieval_qa = RetrievalQA.from_chain_type(
        llm=OpenAI(),
        chain_type="stuff",
        retriever=db.as_retriever(),
    )

    # Log and register the chain; the retriever is restored at load time through
    # `load_retriever` applied to the logged `persist_dir`.
    with mlflow.start_run() as run:
        logged_model = mlflow.langchain.log_model(
            retrieval_qa,
            artifact_path="retrieval_qa",
            registered_model_name="RetrievalQA_model",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
        )

In [0]:
# Run-relative URI of the logged chain, assembled from the run id and the "retrieval_qa" artifact path.
"runs:/{}/retrieval_qa".format(run.info.run_id)

In [0]:
# URI MLflow returned for the logged chain; the next cell reloads the model from it.
logged_model.model_uri

In [0]:
# Load the retriever chain
# Reload the logged RetrievalQA chain as a generic pyfunc model; the retriever is
# rebuilt through the `loader_fn` that was registered when the model was logged.
loaded_model = mlflow.pyfunc.load_model(model_uri=logged_model.model_uri)

In [0]:
# Smoke-test the reloaded chain with a single {"query": ...} record.
sample_record = {"query": "What did the president say about Ketanji Brown Jackson"}
print(loaded_model.predict([sample_record]))