# 05_llm_agent_tests
Purpose: Test the metadata Q&A agent and the Spark optimization agent.  
Author: Janak  
Date: 2025-11-26

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import json, os, io, sys, textwrap
from typing import Dict, Any, List

# Spark session used by every cell below.
spark = SparkSession.builder.getOrCreate()

# Paths (adjust if you changed them)
metadata_path = "src/metadata/metadata_schema.json"
silver_base = "/tmp/delta/silver"
gold_base = "/tmp/delta/gold"

# Load metadata (dict). The encoding is pinned to UTF-8 so that parsing does
# not depend on the platform/locale default when the file holds non-ASCII text.
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

print("Loaded metadata tables:", list(metadata.keys()))


In [0]:
def pretty_print(obj):
    """Pretty-print ``obj`` to stdout with a 2-space indent, non-compact layout."""
    from pprint import PrettyPrinter

    PrettyPrinter(indent=2, compact=False).pprint(obj)


In [0]:
# Basic deterministic Q&A helpers using metadata dict

def get_table_schema(table_name: str):
    """Return the ``schema`` entry recorded for ``table_name`` in ``metadata``.

    Returns an empty dict when the table has no ``schema`` key, and the
    string ``"Table not found: <name>"`` when the table is missing from
    ``metadata`` (or its entry is empty).
    """
    entry = metadata.get(table_name)
    return entry.get("schema", {}) if entry else f"Table not found: {table_name}"

def get_table_description(table_name: str):
    """Return the ``description`` recorded for ``table_name`` in ``metadata``.

    Falls back to ``"No description provided."`` when the key is absent, and
    to ``"Table not found: <name>"`` when the table is missing from
    ``metadata`` (or its entry is empty).
    """
    entry = metadata.get(table_name)
    if entry:
        return entry.get("description", "No description provided.")
    return f"Table not found: {table_name}"

def find_tables_by_owner(owner: str) -> List[str]:
    """Return the names of all tables whose ``owner`` matches ``owner``.

    The comparison is case-insensitive. A table whose ``owner`` is missing or
    null is treated as having an empty owner, so a JSON ``"owner": null`` no
    longer raises ``AttributeError``.
    """
    wanted = owner.lower()
    return [
        name
        for name, info in metadata.items()
        # `or ""` covers both a missing key and an explicit null value.
        if (info.get("owner") or "").lower() == wanted
    ]

def tables_with_freshness_less_than(hours: int) -> List[str]:
    """Return tables whose ``freshness_hours`` is set and strictly below ``hours``.

    Tables without a ``freshness_hours`` value (missing or null) are skipped.
    """
    selected = []
    for name, info in metadata.items():
        freshness = info.get("freshness_hours")
        if freshness is None:
            continue
        if freshness < hours:
            selected.append(name)
    return selected

def get_quality_rules(table_name: str):
    """Return the ``quality_rules`` entry recorded for ``table_name``.

    Returns an empty dict when the table defines no rules, and the string
    ``"Table not found: <name>"`` when the table is missing from ``metadata``
    (or its entry is empty).
    """
    entry = metadata.get(table_name)
    return entry.get("quality_rules", {}) if entry else f"Table not found: {table_name}"


In [0]:
# Very small rule-based natural language parser to map a few common user intents to functions above.

def metadata_qna_agent(query: str) -> str:
    """Answer a small set of metadata questions using rule-based matching.

    Supported intents (matched case-insensitively, checked in this order):

    * ``"show schema for <table>"`` / ``"schema for <table>"``
    * ``"who owns <table>"`` / ``"owner of <table>"``
    * ``"tables owned by <owner>"``
    * ``"... freshness less than <N>"``
    * ``"quality rules for <table>"`` / ``"quality rules of <table>"``

    For table intents the table name is taken to be the last word of the
    query. Anything else falls back to a keyword search over table names,
    descriptions and column names.

    Args:
        query: Free-text question from the user.

    Returns:
        A human-readable answer string (JSON text for schema / quality rules).
    """
    import re  # function-scope import: `re` is not imported at file level

    # Words that carry no table-specific meaning. Without this filter the
    # substring fallback matches almost everything (e.g. "to" is contained
    # in "customer", "total", "inventory", ...).
    stopwords = {
        "a", "an", "the", "of", "to", "for", "in", "on", "by", "and", "or",
        "is", "are", "me", "show", "list", "find", "which", "what", "with",
        "about", "have", "has", "table", "tables", "related",
    }
    wrapping = "`'\".,;:?!"

    # "?" and "!" never belong to a table or owner name, so drop them up front:
    # "Who owns orders?" must resolve the table "orders", not "orders?".
    q = query.strip().lower().rstrip("?! \t")
    tokens = q.split()

    def last_table_token() -> str:
        # Strip quoting/sentence punctuation the user may have put around the name.
        return tokens[-1].strip(wrapping) if tokens else ""

    if q.startswith(("show schema for", "schema for")):
        # format: "show schema for orders"
        return json.dumps(get_table_schema(last_table_token()), indent=2)

    if "who owns" in q or "owner of" in q:
        # "who owns orders" or "owner of orders"
        table = last_table_token()
        owner = metadata.get(table, {}).get("owner", "Unknown")
        return f"Owner of `{table}`: {owner}"

    if "tables owned by" in q:
        # "tables owned by data-platform"
        owner = q.split("tables owned by")[-1].strip()
        return str(find_tables_by_owner(owner))

    if "freshness less than" in q:
        # "which tables have freshness less than 48 hours"
        m = re.search(r"freshness less than (\d+)", q)
        if m:
            return str(tables_with_freshness_less_than(int(m.group(1))))
        # No number given: fall through to the remaining rules / fallback.

    if "quality rules for" in q or "quality rules of" in q:
        return json.dumps(get_quality_rules(last_table_token()), indent=2)

    # fallback: search metadata descriptions & table names for keywords
    keywords = []
    for token in tokens:
        word = token.strip(wrapping)
        if len(word) > 2 and word not in stopwords:
            keywords.append(word)

    matched = []
    for tname, info in metadata.items():
        text_blob = " ".join([
            tname,
            info.get("description") or "",  # tolerate a null description
            " ".join((info.get("schema") or {}).keys()),
        ]).lower()
        if any(k in text_blob for k in keywords):
            matched.append(tname)
    if matched:
        return f"Possible related tables: {matched}"
    return "I do not understand the query. Try: 'show schema for orders', 'who owns orders', 'tables owned by data-platform', 'freshness less than 48', or 'quality rules for orders'."


In [0]:
# Smoke-test queries for metadata_qna_agent: the first five each target one of
# its explicit rules; the last one has no matching rule and therefore
# exercises the keyword fallback.
examples = [
    "Show schema for orders",
    "Who owns orders",
    "Tables owned by data-platform",
    "Which tables have freshness less than 72",
    "Quality rules for orders",
    "Show tables related to customer"
]

# Print each question with the agent's answer, separated by a dashed rule.
for q in examples:
    print("Q:", q)
    print("A:", metadata_qna_agent(q))
    print("-" * 50)


In [0]:
# OPTIONAL: Use OpenAI embeddings to build a metadata index and answer fuzzy queries.
# This is a scaffold. To enable:
# 1) Store your OPENAI_API_KEY in Databricks Secrets and set env var: os.environ["OPENAI_API_KEY"] = dbutils.secrets.get(scope, key)
# 2) Uncomment and run the code below.
#
# NOTE: Running embeddings is optional and may incur API cost.
# NOTE: The scaffold uses the openai>=1.0 client interface (`OpenAI().embeddings.create`);
#       the older module-level `openai.Embedding.create` call was removed in openai 1.0.
#
# pip install "openai>=1.0" tiktoken scikit-learn numpy  (if not already installed in your environment)

"""
import os
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# set your key, example using Databricks secrets (uncomment and edit)
# os.environ["OPENAI_API_KEY"] = dbutils.secrets.get("your_scope","openai_key")
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

EMBEDDING_MODEL = "text-embedding-3-small"

# Build documents: one doc per table with description + schema text
docs = []
table_names = []
for t,info in metadata.items():
    text = info.get("description","") + " " + " ".join([f"{c}:{typ}" for c,typ in info.get("schema",{}).items()])
    docs.append(text)
    table_names.append(t)

# get embeddings for docs (response items are objects, not dicts, in openai>=1.0)
resp = client.embeddings.create(input=docs, model=EMBEDDING_MODEL)
doc_embeddings = np.array([item.embedding for item in resp.data])

def embedding_query(q):
    q_resp = client.embeddings.create(input=[q], model=EMBEDDING_MODEL)
    q_emb = np.array(q_resp.data[0].embedding)
    sims = cosine_similarity([q_emb], doc_embeddings)[0]
    idx = int(sims.argmax())
    return table_names[idx], float(sims[idx])

# Example:
print("Embedding retrieve:", embedding_query("Which table contains order status and amount?"))
"""


In [0]:
# Helpers to capture DataFrame explain plan as text and analyze
def capture_explain(df):
    """Return the text that ``df.explain(extended=True)`` prints to stdout.

    ``sys.stdout`` is temporarily replaced by an in-memory buffer and is
    always restored, even when ``explain`` raises.
    """
    buffer = io.StringIO()
    saved_stdout = sys.stdout
    sys.stdout = buffer
    try:
        df.explain(extended=True)
    finally:
        sys.stdout = saved_stdout
    return buffer.getvalue()

def suggest_optimizations_from_plan(plan_text: str) -> List[str]:
    """Derive coarse tuning hints from the text of a Spark explain plan.

    The plan is lower-cased and scanned for a handful of substrings; each
    rule that matches contributes one suggestion, in a fixed order
    (shuffle, cartesian, aggregation, file scan).

    Args:
        plan_text: Output of ``DataFrame.explain`` as a string. Empty or
            ``None`` is treated as a plan with nothing to report.

    Returns:
        A non-empty list of suggestion strings; when no rule matches it holds
        a single "no obvious issues" message.
    """
    pt = (plan_text or "").lower()
    suggestions = []

    # Rule: detect shuffle/exchange
    if "exchange" in pt or "shuffle" in pt:
        suggestions.append("Detected shuffle/exchange — consider repartitioning by join key or using a broadcast join for small dimension tables.")

    # Rule: detect cartesian (bad). "cartesian" also covers "cartesianproduct".
    if "cartesian" in pt:
        suggestions.append("Cartesian product detected — check join conditions to avoid cross join.")

    # Rule: aggregation present. Spark names these operators *Aggregate
    # (HashAggregate, SortAggregate, ...), so key on "aggregate" alone. The
    # previous `A and B or C` test fired on any "group" substring (e.g. a
    # column called group_id) and missed plans that only showed HashAggregate.
    if "aggregate" in pt:
        suggestions.append("Aggregation exists — ensure proper partitioning (date) or use map-side combiners when possible.")

    # Rule: lots of file scans mention might indicate small files
    if "scan" in pt and "file" in pt:
        suggestions.append("Check file sizes — many small files can degrade performance. Consider compaction (OPTIMIZE).")

    if not suggestions:
        suggestions.append("No obvious issues detected from plan text. Consider reviewing shuffle/joins manually.")
    return suggestions


In [0]:
def optimization_agent_for_query(df):
    """Print the start of ``df``'s explain plan and return tuning suggestions.

    At most the first 4000 characters of the captured plan are printed; the
    full plan text is passed to ``suggest_optimizations_from_plan``.
    """
    plan_text = capture_explain(df)
    print("------ PHYSICAL PLAN (truncated) ------")
    print(plan_text[:4000])  # show start
    return suggest_optimizations_from_plan(plan_text)


In [0]:
# Build a sample join query using Silver tables (if they exist)
# End-to-end check of the optimization agent on a left join of the Silver
# `orders` and `payments` Delta tables read from `silver_base`.
try:
    orders = spark.read.format("delta").load(f"{silver_base}/orders")
    payments = spark.read.format("delta").load(f"{silver_base}/payments")
    # a sample join that could cause shuffle
    df_test = orders.join(payments, on="order_id", how="left")
    # count() is a Spark action, so this line actually executes the join.
    print("Rows in joined df:", df_test.count())
    suggestions = optimization_agent_for_query(df_test)
    print("\nSuggestions:")
    for s in suggestions:
        print("-", s)
except Exception as e:
    # Any failure is reported instead of raised, so the remaining cells still
    # run (presumably a deliberate best-effort for demo environments).
    print("Could not run optimization test (maybe silver tables missing):", e)


In [0]:
# Find enum mismatches found earlier in quality report for orders (example)
# If you ran 02_metadata_layer and saved validation_df to /tmp/delta/quality_report, load it:
# Summarise `accepted_values` violations recorded for the `orders` table in
# the Delta quality report and print a canned explanation when any exist.
try:
    v = spark.read.format("delta").load("/tmp/delta/quality_report")
    orders_violations = v.filter((col("table_name") == "orders") & (col("rule_type") == "accepted_values"))
    # display() is not defined in this file — presumably the Databricks
    # notebook built-in; confirm when running elsewhere.
    display(orders_violations)
    # Sum of `issue_count` over the filtered rows (first column of the single result row).
    count_violations = orders_violations.agg({"issue_count":"sum"}).collect()[0][0]
    print("Total enum violations for orders (accepted_values):", count_violations)
    # Agent natural-language explanation (basic)
    # The truthiness test also skips a None total before the numeric comparison.
    if count_violations and count_violations > 0:
        # NOTE: the explanation text is hard-coded; it is not derived from the rows above.
        print("\nAgent explanation:")
        print("The 'orders.status' column contains values not in the accepted list (pending, shipped, delivered, cancelled).")
        print("Example invalid value(s) may include 'returned'. You can either add 'returned' to accepted_values or mark returns via a separate column and update rules accordingly.")
except Exception as e:
    # Reported rather than raised so the remaining cells still run.
    print("Quality report not found or error:", e)


In [0]:
# Record a few sample question/answer pairs from the rule-based agent and
# persist them as JSON for later reference.
examples_interactions = [
    {"query": question, "answer": metadata_qna_agent(question)}
    for question in (
        "Show schema for orders",
        "Who owns orders",
        "Which tables have freshness less than 72",
    )
]

out_path = "/dbfs/tmp/agent_examples.json"
with open(out_path, "w") as f:
    f.write(json.dumps(examples_interactions, indent=2))
print("Saved examples to:", out_path)
display(examples_interactions)


Markdown (Next Steps & How to Wire to FastAPI)
# Next steps
- To upgrade Q&A: enable embeddings (OpenAI/Azure) and build a vector store of metadata docs.
- To upgrade optimization: add rules from Spark history/metrics and use job metrics to estimate sizes.
- To expose agent: wire `metadata_qna_agent()` and `optimization_agent_for_query()` to FastAPI endpoints (src/api/fastapi_app.py).
- Optional: store agent logs and user queries in a Delta table for observability.
