# In [0]:
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
import os
from mlflow import log_param, log_metric

# --- Ingest the knowledge base ---------------------------------------------
# The PDF is read from a Unity Catalog volume / mounted path.
loader = PyPDFLoader("/Volumes/company_docs/internal_knowledge_base.pdf")
documents = loader.load()

# --- Chunk the loaded documents --------------------------------------------
# Overlapping chunks (size 500, overlap 50) are what gets embedded below.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
)
docs = splitter.split_documents(documents)

# --- Embed the chunks and build/persist the FAISS index --------------------
# The API key is read once from the environment and shared by the embedding
# client and the LLM created further down.
_openai_api_key = os.getenv("OPENAI_API_KEY")
embeddings = OpenAIEmbeddings(openai_api_key=_openai_api_key)
vectorstore = FAISS.from_documents(docs, embeddings)
vectorstore.save_local("/dbfs/tmp/faiss_index")

# --- Retrieval and generation components used by the router ----------------
# The retriever is created with the vector store's default settings; the LLM
# is configured with temperature 0.
retriever = vectorstore.as_retriever()
llm = OpenAI(temperature=0, openai_api_key=_openai_api_key)

# Query intent classification
def classify_query_intent(query):
    """Return the intent label for *query* using ordered keyword rules.

    Matching is a case-insensitive substring test. Rules are evaluated in
    the order listed and the first rule whose keywords are ALL present wins,
    so a query mentioning both "risk" and "fraud" is labelled
    ``"risk_management"``. Queries matching no rule fall back to
    ``"generic_rag"``.

    Args:
        query: The user's free-text question.

    Returns:
        One of ``"customer_segmentation"``, ``"marketing_mix_model"``,
        ``"churn_prediction"``, ``"risk_management"``, ``"fraud_detection"``
        or ``"generic_rag"``.
    """
    text = query.lower()

    # (required substrings, intent label) -- order is significant.
    rules = (
        (("segment", "customer"), "customer_segmentation"),
        (("media lift",), "marketing_mix_model"),
        (("churn",), "churn_prediction"),
        (("risk",), "risk_management"),
        (("fraud",), "fraud_detection"),
    )

    for keywords, intent in rules:
        if all(keyword in text for keyword in keywords):
            return intent
    return "generic_rag"

def _score_registered_model(model_uri, inputs):
    """Score a small Spark DataFrame with an MLflow-registered model.

    Args:
        model_uri: MLflow model URI, e.g. ``"models:/<name>/<stage>"``.
        inputs: Spark DataFrame holding the rows to score. It is converted
            to pandas, so callers should keep it small.

    Returns:
        Whatever the loaded model's ``predict`` returns for those rows.
    """
    # Imported lazily so intents that never score a model do not need it.
    import mlflow.pyfunc

    model = mlflow.pyfunc.load_model(model_uri)
    return model.predict(inputs.toPandas())


# Route query based on intent
def route_query(query, retriever, llm):
    """Classify *query* and dispatch it to the matching model or RAG chain.

    The query and its detected intent are logged as MLflow params, and the
    length of the stringified result is logged as the ``response_length``
    metric.

    Relies on a ``spark`` session being available as a global (as in a
    Databricks notebook) for the model-scoring intents.

    Args:
        query: The user's free-text question.
        retriever: Retriever handed to the ``RetrievalQA`` chain for the
            ``generic_rag`` intent.
        llm: Language model handed to the ``RetrievalQA`` chain for the
            ``generic_rag`` intent.

    Returns:
        - ``customer_segmentation``: collected rows of the batch-scored
          feature table.
        - ``risk_management`` / ``fraud_detection``: the registered model's
          predictions for the first 10 signal rows.
        - ``generic_rag``: the answer produced by the QA chain.
        - any other intent: a fixed "not supported" message.
    """
    intent = classify_query_intent(query)
    # NOTE(review): MLflow params are expected to be write-once per run, so a
    # second call inside the same active run with a different query may be
    # rejected -- consider one (nested) run per query. TODO confirm.
    log_param("query", query)
    log_param("intent", intent)

    if intent == "customer_segmentation":
        # Imported lazily: only this branch needs the Feature Store client.
        from databricks.feature_store import FeatureStoreClient

        fs = FeatureStoreClient()
        model_uri = "models:/customer_segmentation_model/Production"
        inputs = spark.sql("""
            SELECT * FROM catalog.analytics.customer_features
            WHERE age = 65 AND gender = 'female'
        """)
        preds = fs.score_batch(model_uri, inputs)
        result = preds.collect()

    elif intent == "risk_management":
        model_uri = "models:/risk_model/Production"
        inputs = spark.sql("SELECT * FROM catalog.risk.signals LIMIT 10")
        # A registry URI is not a table name, so spark.table(model_uri) could
        # never succeed; load the model and score the selected rows instead.
        result = _score_registered_model(model_uri, inputs)

    elif intent == "fraud_detection":
        model_uri = "models:/fraud_detection_model/Production"
        inputs = spark.sql("SELECT * FROM catalog.fraud.transaction_signals LIMIT 10")
        # Same fix as the risk branch: score the rows with the registered model.
        result = _score_registered_model(model_uri, inputs)

    elif intent == "generic_rag":
        qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
        result = qa_chain.run(query)

    else:
        # Intents the classifier can emit but that have no route here
        # (e.g. "marketing_mix_model", "churn_prediction") end up on this path.
        result = "Intent not supported or no route available."

    log_metric("response_length", len(str(result)))
    return result

# Demo run: push one sample question through the router and print the answer.
# The question mentions both "risk" and "fraud"; the classifier tests "risk"
# first, so it is routed as a risk_management query.
query = "What are our risk signals for fraud detection?"
response = route_query(query=query, retriever=retriever, llm=llm)
print("\nRESPONSE:\n", response)

# ---------------------------------------------------------
# Logging + Access Control Notes:
# ---------------------------------------------------------
# - Use Unity Catalog to set read permissions on PDF source paths
# - Enable Audit Logs via Admin Console > Audit Logs > Destination (e.g., S3, Azure Monitor)
# - Monitor and restrict API usage with IP allowlists and PAT (personal access token) expiration policies