In [55]:
# Standard libs
import os, re, glob, json, sqlite3, textwrap
from IPython.display import Markdown, display
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv

# Data
import pandas as pd

# LangChain core + tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.tools import tool

# Optional schema helpers (list/info) if you want the agent to peek at tables/columns
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool

In [None]:
from llm_connector import langchain_groq_llm_connector

# SECURITY: never hardcode credentials in a notebook. The key is read from the
# environment (optionally populated from a local, git-ignored `.env` file via
# python-dotenv). Any key that was previously committed in this cell must be
# treated as compromised and revoked in the Groq console.
load_dotenv()
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError(
        "GROQ_API_KEY is not set. Export it in your shell or add it to a .env file."
    )

llm = langchain_groq_llm_connector(GROQ_API_KEY, "openai/gpt-oss-20b")
# Bind provider-side tools and generation settings onto the model.
llm = llm.bind(
    tools=[{"type":"browser_search"},{"type":"code_interpreter"}],
    tool_choice="auto",
    reasoning_effort="medium",
    top_p=1,
    max_completion_tokens=8192,
)


# ---- Safety knobs for SQL and results ----------------------------------------
SQL_TIMEOUT_SECONDS = 15      # wall clock per query
SQL_MAX_RETURN_ROWS = 10000   # cap result rows to avoid huge pulls

Connecting to Groq LLM service…
Connected to Groq model 'openai/gpt-oss-20b'.


In [57]:
from db_loader import prepare_citibike_database

# Alternative (currently disabled): let the loader build/open the database.
#db_path, conn, run_query = prepare_citibike_database()

# Location of the SQLite file. Override with the SQLITE_DB_PATH environment
# variable so the notebook is not tied to one machine; the default preserves
# the path this notebook was originally run with.
DB_PATH = os.environ.get(
    "SQLITE_DB_PATH",
    "/Users/alikarami/Documents/AI Workshop/Sample Data/database.sqlite",
)
# SQLAlchemy-style URI consumed by SQLDatabase.from_uri further down.
SQLITE_URI = f"sqlite:///{DB_PATH}"

In [58]:
# Ensure the "outputs" directory exists; the SQL and pandas tools below write
# their CSV/JSON result files into it. exist_ok=True makes re-running the cell safe.
os.makedirs("outputs", exist_ok=True)

In [59]:
# -- 2) HELPERS ----------------------------------------------------------------
def _is_read_only(sql: str) -> bool:
    """Allow only SELECT and PRAGMA table_info (read-only)."""
    forbidden = r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|REPLACE|TRUNCATE|ATTACH|DETACH|VACUUM|BEGIN|COMMIT)\b"
    if re.search(forbidden, sql, flags=re.IGNORECASE):
        return False
    if re.search(r"\bPRAGMA\b", sql, re.IGNORECASE) and "table_info" not in sql.lower():
        return False
    return True

def _add_limit_if_missing(sql: str, limit: int) -> str:
    """Inject a LIMIT if the user didn't specify one on SELECT queries."""
    q = sql.strip().rstrip(";")
    if not q.lower().startswith("select"):
        return q + ";"
    if re.search(r"\blimit\b", q, flags=re.IGNORECASE):
        return q + ";"
    return f"SELECT * FROM (\n{q}\n) LIMIT {limit};"

# -- 2a) SQL tool: run a safe read-only query and save to CSV/JSON --------------
# NOTE: the docstring of a @tool function is used as the tool description shown
# to the model, so it is kept short and unchanged; maintainer notes live in
# comments instead.
@tool("run_sqlite_query", return_direct=False)
def run_sqlite_query(query: str, limit: int = SQL_MAX_RETURN_ROWS) -> str:
    """
    Run a **read-only** SQL query against the local SQLite DB.
    Safety:
      - Only SELECT or PRAGMA table_info allowed.
      - LIMIT injected if missing (for SELECT).
    Returns a small JSON string holding paths & preview.
    """
    import time  # local import: only needed for the wall-clock deadline below

    if not _is_read_only(query):
        return json.dumps({"error": "Only read-only queries allowed (SELECT/PRAGMA table_info)."})

    # The limit is interpolated into SQL, so coerce it to an int and never let
    # the caller exceed the configured cap.
    try:
        limit = int(limit)
    except (TypeError, ValueError):
        limit = SQL_MAX_RETURN_ROWS
    limit = max(1, min(limit, SQL_MAX_RETURN_ROWS))

    # Ensure limit for safety. The helper does its own word-boundary check for
    # an existing LIMIT, so names such as `speed_limit` no longer skip the cap.
    if query.strip().lower().startswith("select"):
        query = _add_limit_if_missing(query, limit)

    # busy_timeout only bounds how long SQLite waits on a locked database; the
    # progress handler below is what enforces the wall-clock budget for a
    # long-running query (a non-zero return value aborts the statement).
    deadline = time.monotonic() + SQL_TIMEOUT_SECONDS

    def _abort_when_overdue() -> int:
        return 1 if time.monotonic() > deadline else 0

    # Execute with sqlite3; pandas makes it easy to load the result
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH, check_same_thread=False)
        conn.execute(f"PRAGMA busy_timeout = {int(SQL_TIMEOUT_SECONDS * 1000)};")
        # Defence in depth: even if a write slipped past _is_read_only, this
        # connection refuses to modify the database.
        conn.execute("PRAGMA query_only = ON;")
        conn.set_progress_handler(_abort_when_overdue, 10_000)
        df = pd.read_sql_query(query, conn)
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return json.dumps({"error": f"SQLite error: {e}"})
    finally:
        if conn is not None:
            conn.close()

    # Persist to disk for the next tool (files are overwritten on every call)
    csv_path  = os.path.join("outputs", "sql_result.csv")
    json_path = os.path.join("outputs", "sql_result.json")
    df.to_csv(csv_path, index=False)
    df.to_json(json_path, orient="records")

    preview_md = df.head(10).to_markdown(index=False) if not df.empty else "(no rows)"
    return json.dumps({
        "rows": len(df),
        "columns": list(df.columns),
        "csv_path": csv_path,
        "json_path": json_path,
        "preview": preview_md,
        "sql_used": query
    })

# -- 2b) Pandas tool: simple groupby/aggregate pipeline -------------------------
# NOTE: the docstring of a @tool function is used as the tool description shown
# to the model, so it is kept unchanged; maintainer notes live in comments.
#
# Returns a JSON string. On success it holds "final_csv", "final_json", "rows",
# "columns", "preview" and, only when something in the spec was skipped, a
# "warnings" list. On failure it holds a single "error" message (plus
# "available_columns" for unknown group-by columns).
@tool("pandas_aggregate", return_direct=False)
def pandas_aggregate(csv_path: str,
                     analysis_type: str = "groupby_aggregate",
                     groupby: Optional[List[str]] = None,
                     metrics: Optional[List[Dict[str, str]]] = None,
                     filters: Optional[List[Dict[str, Any]]] = None) -> str:
    """
    Load CSV produced by the SQL tool and compute a final table.
    Supports:
      - analysis_type="groupby_aggregate" (default) with:
        * groupby: list of columns
        * metrics: list of {"column":..., "agg": one of sum|mean|count|max|min|median|std|nunique}
        * filters: optional list of {"column":..,"op":..,"value":..}
      - analysis_type="describe" as a fallback summary
    """
    import operator  # local import: comparison operators for the filter helper

    if not os.path.exists(csv_path):
        return json.dumps({"error": f"CSV not found: {csv_path}"})

    df = pd.read_csv(csv_path)

    # Anything skipped on a best-effort basis is reported back to the agent
    # instead of being dropped silently.
    warnings: List[str] = []
    comparators = {">": operator.gt, ">=": operator.ge, "<": operator.lt, "<=": operator.le}
    allowed_aggs = {"sum", "mean", "count", "max", "min", "median", "std", "nunique"}

    # Quick filter helper: applies each filter spec in order; unusable specs are
    # skipped (with a warning) rather than failing the whole call.
    def _apply_filters(frame, specs):
        if not specs:
            return frame
        out = frame.copy()
        for spec in specs:
            col, op, val = spec.get("column"), spec.get("op"), spec.get("value")
            if col not in out.columns:  # skip unknown columns
                warnings.append(f"filter skipped: unknown column {col!r}")
                continue
            if op in ("==", "=", "eq"):
                out = out[out[col] == val]
            elif op in comparators:
                # Direct Series comparison works for any column name, unlike
                # DataFrame.query which breaks on names with spaces/punctuation.
                try:
                    out = out[comparators[op](out[col], val)]
                except (TypeError, ValueError) as exc:
                    warnings.append(f"filter skipped: cannot apply {col!r} {op} {val!r} ({exc})")
            elif op == "in":
                val = val if isinstance(val, list) else [val]
                out = out[out[col].isin(val)]
            elif op == "contains":
                out = out[out[col].astype(str).str.contains(str(val), na=False)]
            else:
                warnings.append(f"filter skipped: unsupported op {op!r}")
        return out

    if analysis_type == "describe":
        # Summary of the unfiltered CSV; filters are not applied on this path.
        final_df = df.describe(include="all").reset_index()
    else:
        # Default: groupby_aggregate
        filtered = _apply_filters(df, filters)
        if not groupby or not metrics:
            final_df = filtered.head(50)  # if underspecified, just show sample
        else:
            unknown_keys = [c for c in groupby if c not in filtered.columns]
            if unknown_keys:
                return json.dumps({
                    "error": f"Unknown groupby column(s): {unknown_keys}",
                    "available_columns": list(filtered.columns),
                })

            agg_map: Dict[str, List[str]] = {}
            for m in metrics:
                col = m.get("column")
                agg = str(m.get("agg") or "sum").lower()
                if col not in filtered.columns:
                    warnings.append(f"metric skipped: unknown column {col!r}")
                    continue
                if agg not in allowed_aggs:
                    warnings.append(f"metric skipped: unsupported agg {agg!r} for column {col!r}")
                    continue
                per_col = agg_map.setdefault(col, [])
                if agg not in per_col:  # pandas rejects duplicate function names
                    per_col.append(agg)

            if not agg_map:
                final_df = filtered.head(50)
            else:
                try:
                    out = filtered.groupby(groupby).agg(agg_map)
                except Exception as exc:
                    # Tool boundary: surface the failure to the agent as JSON.
                    return json.dumps({"error": f"Aggregation failed: {exc}"})
                # List-valued agg specs yield (column, agg) tuples; flatten them
                # to "column__agg" names.
                out.columns = ["__".join(map(str, c)) if isinstance(c, tuple) else c
                               for c in out.columns]
                final_df = out.reset_index()

    result_csv  = os.path.join("outputs", "final_result.csv")
    result_json = os.path.join("outputs", "final_result.json")
    final_df.to_csv(result_csv, index=False)
    final_df.to_json(result_json, orient="records")

    payload = {
        "final_csv": result_csv,
        "final_json": result_json,
        "rows": len(final_df),
        "columns": list(final_df.columns),
        "preview": final_df.head(15).to_markdown(index=False)
    }
    if warnings:
        payload["warnings"] = warnings
    return json.dumps(payload)


In [60]:
# -- 3) PROMPT + AGENT ---------------------------------------------------------
# Schema helpers: a SQLDatabase wrapper over the same SQLite file, exposed to the
# agent through a list-tables tool and a table-info tool.
db = SQLDatabase.from_uri(SQLITE_URI, sample_rows_in_table_info=2)
list_tables_tool, info_table_tool = ListSQLDatabaseTool(db=db), InfoSQLDatabaseTool(db=db)

# System instructions for the agent (sent verbatim as the system message).
SYSTEM_RULES = textwrap.dedent("""
You are a careful data analyst working with a local SQLite database.
You MUST follow these rules:

- SAFETY: Only run read-only SQL (SELECT or PRAGMA table_info). No writes or DDL.
- EFFICIENCY: Avoid huge pulls. If the user didn't specify limits/time windows, start small.
- CHOOSE METHOD: Decide the simplest correct analysis (groupby aggregate, describe, or basic stats).
- PIPELINE:
  1) In case any metrics are requested to be calculated, clarify your formula for that metrics first.
  2) For the metric's only, try to search online for refernces of those metrics and clarify your references.
  3) (Optional) Inspect schema with list_tables/info if unsure.
  4) Write a SELECT query with the minimal set of columns and filters needed.
  5) Call `run_sqlite_query` to execute and save results.
  6) If aggregation is needed, call `pandas_aggregate` with a structured spec.
- OUTPUT FORMAT:
  - Start with a short PLAN (bullets).
  - Show the SQL used (fenced code).
  - Show a preview table of the final result.
  - End with a crisp, non-verbose paragraph describing the data to the user.
""").strip()

# Conversation layout: system rules, the user's request, then the scratchpad
# slot where the agent's tool calls and observations accumulate.
PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_RULES),
        ("human", "User analysis request: {analysis_request}\n\nProceed."),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

# All four tools (two schema helpers, the SQL runner, the pandas aggregator)
# are attached to the same agent.
tools = [
    list_tables_tool,
    info_table_tool,
    run_sqlite_query,
    pandas_aggregate,
]

# Tool-calling agent built from the model, the tools and the prompt above.
agent_runnable = create_tool_calling_agent(llm, tools, PROMPT)

# The executor drives the agent loop. verbose=True prints a readable trace in
# the cell output; return_intermediate_steps=True exposes the tool calls
# programmatically; the loop is capped at 8 iterations.
executor = AgentExecutor(
    agent=agent_runnable,
    tools=tools,
    handle_parsing_errors=True,
    max_iterations=8,
    return_intermediate_steps=True,
    verbose=True,
)
print("Agent ready.")


Agent ready.


In [61]:
# -- 4) RUN THE AGENT ----------------------------------------------------------
# Try your own request here. A few examples:
# "What is the average utilization rate of the bikes for each station in Palo Alto? Meaning that how much each station's bikes are utilized by time?"
# "What are the top 10 stations by departures on weekdays 7–10 AM?"

analysis_request = "What is the average utilization rate of the bikes for each station in Palo Alto? Meaning that how much each station's bikes are utilized by time?"

result = executor.invoke({"analysis_request": analysis_request})

# The AgentExecutor returns:
# - 'output' (final text answer)
# - 'intermediate_steps' ([(AgentAction, observation), ...]) = our action trace
print("\n=== FINAL ANSWER ===")
print(result["output"])

print("\n=== ACTION TRACE (tools & observations) ===")
for i, (action, observation) in enumerate(result["intermediate_steps"], 1):
    print(f"\n-- Step {i}:")
    print("Tool:", getattr(action, "tool", "<unknown>"))
    print("Tool Input:", getattr(action, "tool_input", {}))
    # Observations can be large JSON strings; truncate for readability
    obs_str = str(observation)
    print("Observation:", (obs_str[:800] + "...") if len(obs_str) > 800 else obs_str)
# The final answer is printed once above; it used to be re-printed inside the
# loop for every step, which flooded the trace with duplicates.




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[36;1m[1;3mstation, status, trip, weather[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'station, status, trip, weather'}`


[0m[33;1m[1;3m
CREATE TABLE station (
	id INTEGER, 
	name TEXT, 
	lat NUMERIC, 
	long NUMERIC, 
	dock_count INTEGER, 
	city TEXT, 
	installation_date TEXT, 
	PRIMARY KEY (id)
)

/*
2 rows from station table:
id	name	lat	long	dock_count	city	installation_date
2	San Jose Diridon Caltrain Station	37.3297320000	-121.9017820000	27	San Jose	8/6/2013
3	San Jose Civic Center	37.3306980000	-121.8889790000	15	San Jose	8/5/2013
*/


CREATE TABLE status (
	station_id INTEGER, 
	bikes_available INTEGER, 
	docks_available INTEGER, 
	time TEXT
)

/*
2 rows from status table:
station_id	bikes_available	docks_available	time
2	2	25	2013/08/29 12:06:01
2	2	25	2013/08/29 12:07:01
*/


CREATE TABLE trip (
	id INTEGER, 
	duration INTEGER, 
	start_date TE

In [63]:
# Render the agent's final answer (result["output"]) as formatted Markdown
# instead of plain text. Requires the run cell above to have populated `result`.
display(Markdown(result["output"]))

**PLAN**  
- Clarify the utilization metric:  
  *Utilization = (bikes in use ÷ total docks) = ( dock_count – bikes_available ) ÷ dock_count.*  
- Compute the average of this metric across all status snapshots for each Palo Alto station.  
- Return station id, name, and the average utilization.

**SQL**  
```sql
SELECT
    s.id,
    s.name,
    AVG((s.dock_count - st.bikes_available) * 1.0 / s.dock_count) AS avg_utilization
FROM station s
JOIN status st ON s.id = st.station_id
WHERE s.city = 'Palo Alto'
GROUP BY s.id, s.name;
```

**Preview of Result**  

|   id | name                            |   avg_utilization |
|-----:|:--------------------------------|------------------:|
|   34 | Palo Alto Caltrain Station      |          0.558586 |
|   35 | University and Emerson          |          0.434479 |
|   36 | California Ave Caltrain Station |          0.452549 |
|   37 | Cowper at University            |          0.520756 |
|   38 | Park at Olive                   |          0.495167 |

**Interpretation**  
The table shows, for each Palo Alto station, the average proportion of its docks that were occupied by bikes during the recorded status snapshots. For example, the Palo Alto Caltrain Station had bikes in use on average about 56 % of the time, while the University and Emerson station had a lower average utilization of roughly 43 %. This gives a quick sense of how heavily each station’s bikes are being used.