In [0]:
import re
import json 
import logging 
import asyncio
import requests 
# Databricks secret scope, and the *name* of the key inside that scope that holds the
# Gemini API key. Both are identifiers passed to dbutils.secrets; the API key itself must
# never be written into this notebook.
LLM_API_SECRET_SCOPE = "RAG_GEMINI_LLM"
LLM_API_SECRET_KEY = "INSERT_GEMINI_API_KEY"

# Model notebooks the router can dispatch to. These are relative paths that share one
# project folder (presumably resolved against this notebook's location — confirm).
_MODEL_NOTEBOOK_DIR = "aletha-project"
FRAUD_DETECTION_NOTEBOOK_PATH = f"{_MODEL_NOTEBOOK_DIR}/fraudDetectionModel"
CREDIT_RISK_NOTEBOOK_PATH = f"{_MODEL_NOTEBOOK_DIR}/creditRiskScoreModel"
CUSTOMER_SEGMENTATION_NOTEBOOK_PATH = f"{_MODEL_NOTEBOOK_DIR}/customerSegmentationModel"

In [0]:
# Configure logging for this notebook session.
# basicConfig() is a no-op when the root logger already has handlers, which a notebook
# kernel may install before user code runs. Setting the level on our own logger as well
# keeps INFO records from being dropped by an inherited WARNING threshold.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [0]:
# Verify that the Gemini API key secret is configured.
# NOTE: notebook `dbutils.secrets` only exposes read operations (get, getBytes, list,
# listScopes). There is no `put`, so the secret must be created outside the notebook,
# e.g. with the Databricks CLI:
#     databricks secrets create-scope <LLM_API_SECRET_SCOPE>
#     databricks secrets put-secret <LLM_API_SECRET_SCOPE> <LLM_API_SECRET_KEY>
# Never paste the API key itself into this notebook.
try:
    _configured_secret_keys = {secret.key for secret in dbutils.secrets.list(LLM_API_SECRET_SCOPE)}
except Exception as e:
    # Missing scope or no permission: report it, but let model routing still work.
    _configured_secret_keys = set()
    logger.warning(f"Could not list secret scope '{LLM_API_SECRET_SCOPE}': {e}")

if LLM_API_SECRET_KEY in _configured_secret_keys:
    logger.info(f"Found secret '{LLM_API_SECRET_KEY}' in scope '{LLM_API_SECRET_SCOPE}'.")
else:
    logger.warning(
        f"Secret '{LLM_API_SECRET_KEY}' not found in scope '{LLM_API_SECRET_SCOPE}'. "
        "Create it with the Databricks CLI before asking general (LLM) questions."
    )

In [0]:
def _lookup_current_user() -> str:
    """Best-effort lookup of the notebook user for audit logging.

    Tries the original ``dbutils.notebook.getContext()`` call first, then the Python-side
    path through the JVM ``dbutils`` entry point, which is the form usually available from
    Python notebooks. Never raises: returns ``"unknown_user"`` if no lookup succeeds.
    """
    lookups = (
        lambda: dbutils.notebook.getContext().tags().get("user").get(),
        lambda: dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().get("user").get(),
    )
    for lookup in lookups:
        try:
            user = lookup()
        except Exception:
            continue  # This API shape is not available here; try the next one.
        if user:
            return str(user)
    return "unknown_user"


current_user = _lookup_current_user()
logger.info(f"RAG Router initialized by user: {current_user}")

In [0]:
def classify_query(query: str) -> str:
    """
    Classifies the user query to determine which notebook to run or if it's a general inquiry.

    Rules are checked in order (fraud, then credit risk, then customer segmentation), and the
    first match wins. Matching is case-insensitive and tolerates hyphens, plurals and a short
    determiner, e.g. "Segment my customers" or "customer-segmentation".

    Args:
        query (str): The user's input query. ``None`` is treated as an empty query.

    Returns:
        str: The path to the relevant notebook, or "LLM" for a general inquiry.
    """
    query_lower = (query or "").lower()

    # (label used in logs, regex, target). Targets are looked up at call time.
    routes = (
        (
            "Fraud Detection",
            # Covers fraud, frauds, fraudulent, fraud detection, detect fraud, ...
            r"\bfraud\w*",
            FRAUD_DETECTION_NOTEBOOK_PATH,
        ),
        (
            "Credit Risk",
            r"\b(?:credit[\s-]+risk|loan[\s-]+risk|risk[\s-]+scores?|credit[\s-]+score[\s-]+models?)\b",
            CREDIT_RISK_NOTEBOOK_PATH,
        ),
        (
            "Customer Segmentation",
            r"\b(?:customer[\s-]+(?:segment(?:s|ation)?|categor(?:y|ies)|groups?)"
            r"|(?:segment|bucket|group|cluster)\s+(?:(?:my|our|the|all)\s+)?customers?)\b",
            CUSTOMER_SEGMENTATION_NOTEBOOK_PATH,
        ),
    )

    for label, pattern, target in routes:
        if re.search(pattern, query_lower):
            logger.info(f"Query '{query}' classified as: {label}")
            return target

    logger.info(f"Query '{query}' classified as: General Inquiry (LLM)")
    return "LLM"

In [0]:
async def call_llm(prompt: str, model: str = "gemini-2.0-flash", timeout: float = 60.0) -> str:
    """
    Calls the Gemini API to get a response for a general inquiry.

    The HTTP request uses the blocking ``requests`` library, so it is dispatched to the
    running event loop's default executor rather than blocking the loop itself.

    Args:
        prompt (str): The prompt to send to the LLM.
        model (str): Gemini model name used in the ``generateContent`` endpoint.
        timeout (float): Seconds to wait for the HTTP request before giving up.

    Returns:
        str: The LLM's generated text response. On failure, a human-readable error message
        is returned instead; this function reports errors via its return value, not by raising.
    """
    try:
        # Retrieve API key securely from Databricks Secrets
        api_key = dbutils.secrets.get(scope=LLM_API_SECRET_SCOPE, key=LLM_API_SECRET_KEY)
        logger.info("Successfully retrieved LLM API key from Databricks Secrets.")
    except Exception as e:
        logger.error(f"Failed to retrieve LLM API key from secrets: {e}. Please ensure scope '{LLM_API_SECRET_SCOPE}' and key '{LLM_API_SECRET_KEY}' exist and are accessible.")
        return "Error: LLM API key could not be retrieved securely."

    payload = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]}
    # The key goes in a header, not the query string. requests embeds the request URL in
    # HTTPError messages, which are logged and returned below, so a key in the URL would leak.
    api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}

    def _post():
        return requests.post(api_url, headers=headers, json=payload, timeout=timeout)

    try:
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _post)
        response.raise_for_status()  # Raise an HTTPError for bad responses (4xx or 5xx)
    except requests.exceptions.RequestException as e:
        logger.error(f"Error calling LLM API: {e}. Request URL: {api_url}")
        return f"Error calling LLM API: {e}"
    except Exception as e:
        logger.error(f"An unexpected error occurred during LLM call: {e}")
        return f"An unexpected error occurred with LLM call: {e}"

    try:
        result = response.json()
    except ValueError as e:
        # json.JSONDecodeError and requests' JSONDecodeError both subclass ValueError. This is
        # a separate try so the RequestException handler above cannot shadow it.
        logger.error(f"Error decoding LLM response: {e}. Raw response: {response.text}")
        return f"Error decoding LLM response: {e}. (Check raw response in logs)"

    try:
        llm_response_text = result["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError):
        # Missing/empty candidates, content or parts (e.g. a blocked response).
        logger.warning(f"LLM did not return a valid response structure. Response: {result}")
        return "LLM did not return a valid response."

    logger.info("LLM call successful.")
    return llm_response_text


In [0]:
import concurrent.futures

# Maximum time (seconds) a routed model notebook may run before dbutils aborts it.
NOTEBOOK_RUN_TIMEOUT_SECONDS = 3600


def _run_coroutine_blocking(coro):
    """Run *coro* to completion and return its result, even under a running event loop.

    IPython-based kernels execute cells while an asyncio loop is already running in the
    main thread, and ``asyncio.run`` raises RuntimeError there. In that case the coroutine
    is run on a fresh loop in a single worker thread.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # No loop running: the simple path.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def _answer_with_llm(query: str) -> None:
    """Answer a general inquiry with the LLM and print the response."""
    print("\n--- General Inquiry (Answering with LLM) ---")
    logger.info("Initiating LLM call for general inquiry.")
    try:
        llm_response = _run_coroutine_blocking(call_llm(query))
    except Exception as e:
        logger.error(f"Error during LLM inquiry execution: {e}")
        print(f"Failed to get LLM response: {e}")
        return
    print("\nLLM Response:")
    print(llm_response)
    logger.info("LLM inquiry completed.")


def _run_model_notebook(notebook_path: str) -> None:
    """Execute a model notebook with ``dbutils.notebook.run`` and report the outcome.

    ``%run`` cannot be used here: Databricks requires ``%run`` to be alone in its own cell
    with a literal path, so it cannot sit inside an ``if``/``try`` block. ``dbutils.notebook.run``
    runs the target as a separate run (its variables are not shared with this notebook) and
    returns the value the target passes to ``dbutils.notebook.exit``.
    """
    print(f"\n--- Routing to Notebook: {notebook_path} ---")
    print("Executing the target notebook as a separate run. A link to its output will appear below.")
    print("-" * 50)  # Separator for clarity
    logger.info(f"Attempting to execute notebook: {notebook_path}")
    try:
        exit_value = dbutils.notebook.run(notebook_path, NOTEBOOK_RUN_TIMEOUT_SECONDS)
    except Exception as e:
        logger.error(f"Error executing notebook '{notebook_path}': {e}", exc_info=True)  # exc_info to log traceback
        print(f"\nError executing notebook '{notebook_path}': {e}")
        print("Please ensure the notebook path is correct and the notebook runs without errors.")
        return
    print("-" * 50)
    print(f"\nSuccessfully executed notebook: {notebook_path}")
    if exit_value:
        print(f"Notebook exit value: {exit_value}")
    logger.info(f"Notebook '{notebook_path}' execution completed successfully.")
    print("Please review the linked run output above for the model results.")


print("Enter query below and RAG will route it to the appropriate ML model or answer generally.")
print("Examples: 'Run fraud detection model', 'Tell me about credit risk', 'Segment my customers', 'What is Databricks?'")

# Get user input
user_query = input("Your query: ").strip()

logger.info(f"Processing query: '{user_query}'")

if not user_query:
    # An empty prompt would only produce a 400 error from the LLM API.
    target_action = None
    print("\nNo query entered. Please re-run this cell and enter a question or a model request.")
    logger.warning("Empty query received; nothing to route.")
else:
    # Classify the query
    target_action = classify_query(user_query)
    print(f"\nClassified action: {target_action}")

    if target_action == "LLM":
        _answer_with_llm(user_query)
    elif target_action in (FRAUD_DETECTION_NOTEBOOK_PATH, CREDIT_RISK_NOTEBOOK_PATH, CUSTOMER_SEGMENTATION_NOTEBOOK_PATH):
        _run_model_notebook(target_action)
    else:
        print("\n--- Unknown Classification ---")
        logger.warning(f"Query '{user_query}' could not be classified to a known model.")
        print("Could not classify the query to a known model. Please rephrase or try a general inquiry.")

logger.info("--- Databricks RAG Model Router Execution Finished ---")
