In [None]:
!wget -q https://raw.githubusercontent.com/alikarami76/genai-workshop/refs/heads/main/db_loader.py
!wget -q https://raw.githubusercontent.com/alikarami76/genai-workshop/refs/heads/main/llm_connector.py
!wget -q https://raw.githubusercontent.com/alikarami76/genai-workshop/refs/heads/main/requirements.txt

In [None]:
%pip install -r requirements.txt

In [None]:
# Standard libs
import os, re, glob, json, sqlite3, textwrap,time
from IPython.display import Markdown, display
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv

# Data
import pandas as pd

# Visualization libs
import matplotlib
matplotlib.use("Agg")  # headless backend for PNG
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# LangChain core + tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.tools import tool

# Optional schema helpers (list/info) if you want the agent to peek at tables/columns
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool

In [None]:
from llm_connector import langchain_groq_llm_connector

# Credentials: read the Groq key from the environment (python-dotenv also loads a local
# .env file) instead of hard-coding it in the notebook. Fall back to an interactive prompt
# so the key never ends up in saved cells or outputs.
# NOTE(review): GROQ_API_KEY is the variable name langchain_groq uses; adjust if your setup differs.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    from getpass import getpass
    GROQ_API_KEY = getpass("Groq API key: ")
GROQ_MODEL = "openai/gpt-oss-120b"

llm = langchain_groq_llm_connector(GROQ_API_KEY, GROQ_MODEL)
# NOTE(review): create_tool_calling_agent() later calls llm.bind_tools(...). On a
# RunnableBinding that call appears to be delegated to the wrapped chat model, which would
# drop the kwargs bound here (built-in tools, reasoning_effort, max_completion_tokens).
# Confirm against the installed langchain_core version.
llm = llm.bind(
    tools=[{"type":"browser_search"},{"type":"code_interpreter"}],
    tool_choice="auto",
    reasoning_effort="medium",
    top_p=1,
    max_completion_tokens=8192,
)


# ---- Safety knobs for SQL and results ----------------------------------------
SQL_TIMEOUT_SECONDS = 300      # per-query time budget in seconds (used by run_sqlite_query)
SQL_MAX_RETURN_ROWS = 100000   # cap result rows to avoid huge pulls

In [None]:
from db_loader import prepare_citibike_database

# Build or locate the Citibike SQLite file. The helper returns a 3-tuple, presumably
# (database path, open connection, query helper); see db_loader.py to confirm.
DB_PATH, conn, run_query = prepare_citibike_database()
# SQLAlchemy-style URI consumed by SQLDatabase.from_uri in the agent cell.
SQLITE_URI = "sqlite:///" + str(DB_PATH)

In [None]:
# All tool artefacts (CSV/JSON results, PNG charts, HTML maps) are written under ./outputs.
os.makedirs(name="outputs", exist_ok=True)

In [None]:
# -- 2) HELPERS ----------------------------------------------------------------
def _is_read_only(sql: str) -> bool:
    """Allow only SELECT and PRAGMA table_info (read-only)."""
    forbidden = r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|REPLACE|TRUNCATE|ATTACH|DETACH|VACUUM|BEGIN|COMMIT)\b"
    if re.search(forbidden, sql, flags=re.IGNORECASE):
        return False
    if re.search(r"\bPRAGMA\b", sql, re.IGNORECASE) and "table_info" not in sql.lower():
        return False
    return True

def _add_limit_if_missing(sql: str, limit: int) -> str:
    """Inject a LIMIT if the user didn't specify one on SELECT queries."""
    q = sql.strip().rstrip(";")
    if not q.lower().startswith("select"):
        return q + ";"
    if re.search(r"\blimit\b", q, flags=re.IGNORECASE):
        return q + ";"
    return f"SELECT * FROM (\n{q}\n) LIMIT {limit};"

# -- 2a) SQL tool: run a safe read-only query and save to CSV/JSON --------------
@tool("run_sqlite_query", return_direct=False)
def run_sqlite_query(query: str, limit: int = SQL_MAX_RETURN_ROWS) -> str:
    """
    Run a **read-only** SQL query against the local SQLite DB.
    Safety:
      - Only SELECT or PRAGMA table_info allowed.
      - LIMIT injected if missing (for SELECT).
      - The database is opened read-only and a query running longer than the
        configured time budget is interrupted.
    Returns a small JSON string holding paths & preview
    (rows, columns, csv_path, json_path, preview, sql_used), or {"error": ...}.
    """
    from pathlib import Path  # local import: only needed to build the read-only URI

    if not _is_read_only(query):
        return json.dumps({"error": "Only read-only queries allowed (SELECT/PRAGMA table_info)."})

    # _add_limit_if_missing leaves non-SELECT statements and queries that already carry a
    # LIMIT semantically unchanged (it only normalises the trailing ';'), so it can be called
    # unconditionally. The old `"limit" in query` substring test skipped columns like speed_limit.
    query = _add_limit_if_missing(query, limit)

    deadline = time.monotonic() + SQL_TIMEOUT_SECONDS
    ro_conn = None
    try:
        # mode=ro makes SQLite itself reject writes, independent of the lexical screen above.
        uri = Path(DB_PATH).resolve().as_uri() + "?mode=ro"
        ro_conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
        # busy_timeout only bounds waiting for locks...
        ro_conn.execute(f"PRAGMA busy_timeout = {SQL_TIMEOUT_SECONDS*1000};")
        # ...so enforce the wall-clock budget too: a truthy return from the progress handler
        # (polled every 10k VM instructions) interrupts the running statement.
        ro_conn.set_progress_handler(lambda: time.monotonic() > deadline, 10_000)
        df = pd.read_sql_query(query, ro_conn)
    except Exception as e:
        if time.monotonic() > deadline:
            return json.dumps({"error": f"Query exceeded the {SQL_TIMEOUT_SECONDS}s time limit "
                                        f"and was interrupted; narrow filters or aggregate in SQL."})
        return json.dumps({"error": f"SQLite error: {e}"})
    finally:
        if ro_conn is not None:
            ro_conn.close()

    # Persist to disk for the next tool (pandas_aggregate / plot_visualization / map_heatmap)
    os.makedirs("outputs", exist_ok=True)
    csv_path  = os.path.join("outputs", "sql_result.csv")
    json_path = os.path.join("outputs", "sql_result.json")
    df.to_csv(csv_path, index=False)
    df.to_json(json_path, orient="records")

    preview_md = df.head(10).to_markdown(index=False) if not df.empty else "(no rows)"
    return json.dumps({
        "rows": len(df),
        "columns": list(df.columns),
        "csv_path": csv_path,
        "json_path": json_path,
        "preview": preview_md,
        "sql_used": query
    })

# -- 2b) Pandas tool: simple groupby/aggregate pipeline -------------------------
@tool("pandas_aggregate", return_direct=False)
def pandas_aggregate(csv_path: str,
                     analysis_type: str = "groupby_aggregate",
                     groupby: Optional[List[str]] = None,
                     metrics: Optional[List[Dict[str, str]]] = None,
                     filters: Optional[List[Dict[str, Any]]] = None) -> str:
    """
    Load CSV produced by the SQL tool and compute a final table.
    Supports:
      - analysis_type="groupby_aggregate" (default) with:
        * groupby: list of columns
        * metrics: list of {"column":..., "agg": one of sum|mean|count|max|min|median|std|nunique}
        * filters: optional list of {"column":..,"op":..,"value":..}
          where op is one of ==|!=|>|>=|<|<=|in|contains|between
      - analysis_type="describe" as a fallback summary (filters are not applied)
    Metrics on unknown columns are ignored; filters that cannot be applied are listed
    under "skipped_filters" in the result.
    Returns a JSON string with final_csv, final_json, rows, columns and a markdown preview,
    or {"error": ...} when the table cannot be computed (e.g. unknown groupby column).
    """
    import operator  # local import: comparison dispatch for the filter helper

    if not os.path.exists(csv_path):
        return json.dumps({"error": f"CSV not found: {csv_path}"})

    comparisons = {
        "==": operator.eq, "=": operator.eq, "eq": operator.eq,
        "!=": operator.ne, "<>": operator.ne,
        ">": operator.gt, ">=": operator.ge, "<": operator.lt, "<=": operator.le,
    }

    def _filter_rows(frame, specs):
        """Apply row filters; return (filtered frame, specs that could not be applied)."""
        skipped = []
        for spec in specs or []:
            col, op, val = spec.get("column"), spec.get("op"), spec.get("value")
            if col not in frame.columns:
                skipped.append(spec)
                continue
            try:
                # Boolean masks instead of DataFrame.query: works for any column name.
                if op in comparisons:
                    frame = frame[comparisons[op](frame[col], val)]
                elif op == "in":
                    frame = frame[frame[col].isin(val if isinstance(val, list) else [val])]
                elif op == "contains":
                    frame = frame[frame[col].astype(str).str.contains(str(val), na=False)]
                elif op == "between":
                    lo, hi = val
                    frame = frame[frame[col].between(lo, hi)]
                else:
                    skipped.append(spec)
            except (TypeError, ValueError):
                # e.g. comparing a text column with a number: keep rows, report the filter.
                skipped.append(spec)
        return frame, skipped

    try:
        df = pd.read_csv(csv_path)
        skipped = []

        if analysis_type == "describe":
            final_df = df.describe(include="all").reset_index()
        else:
            # Default: groupby_aggregate
            df2, skipped = _filter_rows(df, filters)
            group_cols = [groupby] if isinstance(groupby, str) else list(groupby or [])

            agg_map = {}
            for m in metrics or []:
                col = m.get("column")
                if col in df2.columns:
                    agg_name = str(m.get("agg") or "sum").lower()
                    funcs = agg_map.setdefault(col, [])
                    if agg_name not in funcs:  # pandas rejects duplicate function names
                        funcs.append(agg_name)

            if not group_cols or not agg_map:
                final_df = df2.head(50)  # underspecified: just show a sample
            else:
                missing = [c for c in group_cols if c not in df2.columns]
                if missing:
                    return json.dumps({"error": f"groupby columns not found: {missing}",
                                       "available_columns": [str(c) for c in df2.columns]})
                out = df2.groupby(group_cols).agg(agg_map)
                # ("duration", "mean") -> "duration__mean"
                out.columns = ["__".join(map(str, c)) if isinstance(c, tuple) else str(c)
                               for c in out.columns]
                final_df = out.reset_index()

        os.makedirs("outputs", exist_ok=True)
        result_csv  = os.path.join("outputs", "final_result.csv")
        result_json = os.path.join("outputs", "final_result.json")
        final_df.to_csv(result_csv, index=False)
        final_df.to_json(result_json, orient="records")

        payload = {
            "final_csv": result_csv,
            "final_json": result_json,
            "rows": len(final_df),
            "columns": list(final_df.columns),
            "preview": final_df.head(15).to_markdown(index=False)
        }
        if skipped:
            payload["skipped_filters"] = skipped
        return json.dumps(payload, default=str)
    except Exception as e:
        # Return the error to the agent instead of aborting the whole AgentExecutor run.
        return json.dumps({"error": f"pandas_aggregate failed: {e}"})


In [None]:
# =============================================================================
# NEW: Visualization tools (PNG charts via matplotlib/seaborn, Map heatmaps via folium)
# =============================================================================
import os, json, math, datetime as _dt
from typing import Optional, List, Dict, Any, Union

import pandas as pd
import numpy as np

# Headless plotting
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Optional plotting/mapping backends. Each one degrades to None so the tools below can
# fall back to matplotlib (seaborn) or return a clear error (folium) instead of failing here.
try:
    import seaborn as sns
except Exception:
    sns = None

try:
    import folium
    from folium.plugins import HeatMap
except Exception:
    folium = HeatMap = None  # either import failing disables map heatmaps entirely

OUTPUTS_DIR = "outputs"
os.makedirs(OUTPUTS_DIR, exist_ok=True)

def _now_ts():
    return _dt.datetime.now().strftime("%Y%m%d_%H%M%S")

def _load_dataframe(path: str) -> pd.DataFrame:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Data not found: {path}")
    ext = os.path.splitext(path)[1].lower()
    if ext in (".csv",):
        return pd.read_csv(path)
    if ext in (".json",):
        return pd.read_json(path, orient="records")
    if ext in (".parquet",):
        return pd.read_parquet(path)
    # Fallback: try CSV
    return pd.read_csv(path)

def _apply_filters(df: pd.DataFrame, filters: Optional[List[Dict[str, Any]]]) -> pd.DataFrame:
    if not filters:
        return df
    out = df.copy()
    for f in filters:
        col, op, val = f.get("column"), f.get("op"), f.get("value")
        if col not in out.columns:
            continue
        if op in ("==","=","eq"):
            out = out[out[col] == val]
        elif op in (">",">=","<","<=","!=","<>"):
            try:
                out = out.query(f"`{col}` {op.replace('<>','!=')} @val")
            except Exception:
                pass
        elif op == "in":
            val = val if isinstance(val, list) else [val]
            out = out[out[col].isin(val)]
        elif op == "contains":
            out = out[out[col].astype(str).str.contains(str(val), na=False)]
        elif op == "between":
            try:
                lo, hi = val
                out = out[(out[col] >= lo) & (out[col] <= hi)]
            except Exception:
                pass
    return out

def _normalize_list_arg(v: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    if v is None:
        return None
    if isinstance(v, list):
        return v
    return [str(v)]

# -- 2c) Visualization tool: static charts to PNG ------------------------------
from langchain_core.tools import tool

@tool("plot_visualization", return_direct=False)
def plot_visualization(data_path: str,
                       chart_type: str,
                       x: Optional[str] = None,
                       y: Optional[Union[str, List[str]]] = None,
                       hue: Optional[str] = None,
                       agg: Optional[str] = None,
                       bins: Optional[int] = None,
                       kde: bool = False,
                       figsize: Optional[List[float]] = None,
                       title: Optional[str] = None,
                       xlabel: Optional[str] = None,
                       ylabel: Optional[str] = None,
                       rotate_xticks: Optional[int] = None,
                       filters: Optional[List[Dict[str, Any]]] = None,
                       output_filename: Optional[str] = None) -> str:
    """
    Create a static chart from a CSV/JSON table and save it as PNG in the 'outputs' directory.

    INPUT SCHEMA (named args – pass them directly, NOT as a single JSON string):
    - data_path (str): Path to CSV/JSON produced by prior tools (e.g., outputs/final_result.csv).
    - chart_type (str): One of:
        "line", "area", "bar", "barh", "scatter", "hist", "box", "violin", "pie".
    - x (str, optional): Column name for x-axis or categories.
    - y (str | List[str], optional): Column(s) for y-axis (multi-series allowed for line/bar/area).
    - hue (str, optional): Categorical column to split/colour series (mainly for seaborn charts).
    - agg (str, optional): If grouping needed for bar/pie (e.g., "sum","mean","count","median","max","min").
        For line/area/bar/barh/scatter the aggregated series are plotted under the name
        "<y>__<agg>" (or as "count" when no y is given).
    - bins (int, optional): For histogram bins (chart_type="hist").
    - kde (bool, optional): Overlay KDE on histogram (requires seaborn).
    - figsize (List[float], optional): [width, height] in inches. Default [10,6].
    - title, xlabel, ylabel (str, optional): Labels and title.
    - rotate_xticks (int, optional): Rotate x tick labels by N degrees (e.g., 45).
    - filters (List[dict], optional): Row filters, same format as pandas_aggregate:
        [{"column":"col","op":"==|>|>=|<|<=|!=|in|contains|between","value":...}, ...]
    - output_filename (str, optional): PNG filename to use. A bare filename is saved inside
        'outputs/', and ".png" is appended when no extension is given.
        Default auto: outputs/plot_<timestamp>.png

    OUTPUT (JSON str):
    {
      "image_path": "outputs/plot_20250101_120000.png",
      "chart_type": "bar",
      "columns_used": ["col_x","col_y1","col_y2"],
      "rows_plotted": 123,          # rows of the (filtered / aggregated) table drawn
      "notes": "Short status message"
    }

    BEHAVIOR NOTES FOR THE AGENT:
    - If the user's request mentions charts/plots/visuals that are NOT geospatial heatmaps,
      prefer this tool. Use folium heatmaps via 'map_heatmap' for lat/lng intensity maps.
    - If multiple y series are requested, pass y as a list.
    - For pie charts: set x=categorical, y=numeric, and (optionally) agg="sum" to aggregate.
    - Always point data_path to a file created earlier (e.g., from run_sqlite_query or pandas_aggregate).
    - If seaborn is missing, the tool will fall back to matplotlib where possible.
    """
    fig = None  # closed in `finally` on every exit path, including early error returns
    try:
        df = _load_dataframe(data_path)
        df = _apply_filters(df, filters)
        if df.empty:
            return json.dumps({"error": "No data to plot after applying filters."})

        chart_type = (chart_type or "").lower().strip()
        if chart_type not in {"line","area","bar","barh","scatter","hist","box","violin","pie"}:
            return json.dumps({"error": f"Unsupported chart_type: {chart_type}"})

        # Figure setup: explicit Figure/Axes rather than the pyplot state machine
        w, h = (figsize if (isinstance(figsize, list) and len(figsize) == 2) else [10, 6])
        fig, ax = plt.subplots(figsize=(float(w), float(h)))

        used_cols = []
        y_cols = _normalize_list_arg(y)
        plotted = df  # frame actually drawn; its length is reported as rows_plotted

        # Helper to aggregate if requested. Returns `_df` itself when nothing was aggregated.
        def maybe_aggregate(_df, _x, _ys):
            if not agg:
                return _df
            if _x and _x in _df.columns:
                agg_fun = agg.lower()
                valid = {"sum","mean","count","median","max","min","std","nunique"}
                if agg_fun not in valid:
                    return _df
                if not _ys:
                    # aggregate y-less case (e.g., count by x)
                    return _df.groupby(_x).size().reset_index(name="count")
                # aggregate over provided y columns
                use = [c for c in _ys if c in _df.columns]
                if not use:
                    return _df
                grouped = _df.groupby(_x)[use].agg(agg_fun)
                if isinstance(grouped, pd.Series):
                    grouped = grouped.to_frame(name=f"{use[0]}__{agg_fun}")
                else:
                    grouped.columns = [f"{c}__{agg_fun}" for c in grouped.columns]
                return grouped.reset_index()
            return _df

        # Build the chart
        if chart_type in {"line","area","bar","barh","scatter"}:
            if x is None:
                return json.dumps({"error": f"'x' is required for {chart_type} chart."})
            if y_cols is None and chart_type != "scatter":
                # y missing: plot up to 3 numeric columns as separate series
                numeric = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and c != x]
                y_cols = numeric[:3] if numeric else None
            df2 = maybe_aggregate(df, x, y_cols)
            if df2 is not df:
                # Aggregation renamed the value columns to "<col>__<agg>" (or, without y,
                # produced a single "count" column). Follow the new names, otherwise the
                # column check below always fails for aggregated charts.
                agg_fun = agg.lower()
                if y_cols:
                    y_cols = [f"{c}__{agg_fun}" if f"{c}__{agg_fun}" in df2.columns else c
                              for c in y_cols]
                else:
                    y_cols = ["count"]
            plotted = df2
            if y_cols and all(c in df2.columns for c in y_cols):
                used_cols = [x] + y_cols
                plot_df = df2[[x] + y_cols].copy()
                if chart_type in {"line","area","bar","barh"}:
                    # long format: one row per (x, series) for multi-series plotting
                    m = plot_df.melt(id_vars=[x], var_name="series", value_name="value")
                    if sns:
                        if chart_type == "line":
                            sns.lineplot(data=m, x=x, y="value", hue="series", ax=ax)
                        elif chart_type == "area":
                            # stacked area via pandas/matplotlib on the wide table
                            pvt = m.pivot(index=x, columns="series", values="value").fillna(0)
                            pvt.plot.area(ax=ax)
                        elif chart_type == "bar":
                            sns.barplot(data=m, x=x, y="value", hue="series", ax=ax, errorbar=None)
                        else:
                            # horizontal bars: plot the wide table with pandas
                            pvt = m.pivot(index=x, columns="series", values="value").fillna(0)
                            pvt.plot(kind="barh", ax=ax)
                    else:
                        # Pure matplotlib fallback
                        if chart_type == "line":
                            for s, grp in m.groupby("series"):
                                ax.plot(grp[x], grp["value"], label=s)
                            ax.legend()
                        elif chart_type == "area":
                            pvt = m.pivot(index=x, columns="series", values="value").fillna(0)
                            ax.stackplot(pvt.index, pvt.T.values, labels=pvt.columns)
                            ax.legend(loc="upper left", bbox_to_anchor=(1,1))
                        else:
                            pvt = m.pivot(index=x, columns="series", values="value").fillna(0)
                            pvt.plot(kind="bar" if chart_type == "bar" else "barh", ax=ax)
                else:
                    # scatter: first y column against x, optionally coloured by hue
                    y1 = y_cols[0]
                    hue_col = hue if hue and hue in df2.columns else None
                    used_cols = [x, y1] + ([hue_col] if hue_col else [])
                    if sns:
                        sns.scatterplot(data=df2, x=x, y=y1, hue=hue_col, ax=ax)
                    elif hue_col:
                        for k, grp in df2.groupby(hue_col):
                            ax.scatter(grp[x], grp[y1], label=str(k), alpha=0.9)
                        ax.legend()
                    else:
                        ax.scatter(df2[x], df2[y1])
            else:
                return json.dumps({"error": f"Could not resolve y columns: {y_cols}"})

        elif chart_type == "hist":
            # Histogram over one numeric column (y) or choose first numeric
            col = (y_cols or [None])[0]
            if col is None:
                numeric = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
                if not numeric:
                    return json.dumps({"error": "No numeric columns available for histogram."})
                col = numeric[0]
            used_cols = [col]
            if sns:
                sns.histplot(df[col], bins=bins or 30, kde=kde, ax=ax)
            else:
                ax.hist(df[col].dropna().values, bins=bins or 30)

        elif chart_type in {"box","violin"}:
            if not x or not y_cols:
                return json.dumps({"error": f"'x' and 'y' are required for {chart_type}."})
            col = y_cols[0]
            used_cols = [x, col]
            if sns:
                if chart_type == "box":
                    sns.boxplot(data=df, x=x, y=col, hue=hue, ax=ax)
                else:
                    sns.violinplot(data=df, x=x, y=col, hue=hue, ax=ax, inner="quartile")
            else:
                # Simple matplotlib boxplot fallback (no hue)
                df.boxplot(column=col, by=x, ax=ax)
                fig.suptitle("")  # remove pandas' automatic "Boxplot grouped by" suptitle

        elif chart_type == "pie":
            # Expect x=categorical (labels), y=numeric (values). Can aggregate with agg.
            if not x or not y_cols:
                return json.dumps({"error": "Pie chart requires 'x' (labels) and 'y' (numeric)."})
            col = y_cols[0]
            df2 = df[[x, col]].copy()
            if agg:
                fun = getattr(df2.groupby(x)[col], agg.lower(), None)
                if callable(fun):
                    df2 = fun().reset_index()
            df2 = df2.sort_values(col, ascending=False)
            used_cols = [x, col]
            plotted = df2
            ax.pie(df2[col].values, labels=df2[x].astype(str).values, autopct="%1.1f%%", startangle=90)
            ax.axis("equal")

        # Labels & cosmetics
        if title:
            ax.set_title(title)
        if xlabel:
            ax.set_xlabel(xlabel)
        if ylabel:
            ax.set_ylabel(ylabel)
        if rotate_xticks:
            for tick in ax.get_xticklabels():
                tick.set_rotation(rotate_xticks)

        # Save PNG: bare filenames go into outputs/, and a missing extension defaults to .png
        png_path = output_filename or os.path.join(OUTPUTS_DIR, f"plot_{_now_ts()}.png")
        if not os.path.dirname(png_path):
            png_path = os.path.join(OUTPUTS_DIR, png_path)
        if not os.path.splitext(png_path)[1]:
            png_path += ".png"
        os.makedirs(os.path.dirname(png_path), exist_ok=True)
        fig.tight_layout()
        fig.savefig(png_path, dpi=160)

        if chart_type == "hist":
            rows_plotted = plotted[used_cols[0]].notna().sum()
        else:
            rows_plotted = len(plotted)
        return json.dumps({
            "image_path": png_path,
            "chart_type": chart_type,
            "columns_used": used_cols,
            "rows_plotted": int(rows_plotted),
            "notes": "Chart created successfully."
        })
    except Exception as e:
        return json.dumps({"error": f"plot_visualization failed: {e}"})
    finally:
        if fig is not None:
            plt.close(fig)


# -- 2d) Geospatial heatmap tool: Folium HTML heatmap --------------------------
@tool("map_heatmap", return_direct=False)
def map_heatmap(data_path: str,
                lat_col: str = "lat",
                lng_col: str = "lng",
                weight_col: Optional[str] = None,
                radius: int = 12,
                blur: int = 15,
                min_opacity: float = 0.4,
                max_zoom: int = 18,
                tiles: str = "OpenStreetMap",
                gradient: Optional[Dict[float, str]] = None,
                filters: Optional[List[Dict[str, Any]]] = None,
                output_filename: Optional[str] = None,
                zoom_start: Optional[int] = None) -> str:
    """
    Create an interactive heatmap (HTML) from latitude/longitude data using Folium.

    INPUT SCHEMA (named args – pass them directly):
    - data_path (str): CSV/JSON table with at least latitude & longitude columns.
    - lat_col (str): Name of latitude column (default: "lat"; accepts "latitude").
    - lng_col (str): Name of longitude column (default: "lng"; accepts "lon","long","longitude").
    - weight_col (str, optional): Column with weights/intensity (numeric). Missing or non-numeric
        weights count as 0, and weights are rescaled so the largest one is 1.0.
    - radius (int): Point radius for heat kernel (default 12).
    - blur (int): Amount of blur (default 15).
    - min_opacity (float): Min opacity of the heatmap layer (default 0.4).
    - max_zoom (int): Max zoom level for heat layer (default 18).
    - tiles (str): Base map tiles, e.g., "OpenStreetMap", "CartoDB positron", "CartoDB dark_matter".
    - gradient (dict, optional): {0.2:"#00f", 0.4:"#0ff", 0.6:"#0f0", 0.8:"#ff0", 1.0:"#f00"}.
    - filters (List[dict], optional): Same filter schema as other tools.
    - output_filename (str, optional): HTML filename; a bare name is saved inside 'outputs/'.
        Default auto: outputs/heatmap_<timestamp>.html
    - zoom_start (int, optional): Initial zoom. When omitted, the map auto-fits to the data bounds.

    OUTPUT (JSON str):
    {
      "map_html": "outputs/heatmap_20250101_120500.html",
      "points_plotted": 1420,
      "bounds": [[min_lat,min_lng],[max_lat,max_lng]],
      "weight_column_used": "col" or null,
      "notes": "Open this HTML file in a browser to view the map."
    }

    BEHAVIOR NOTES FOR THE AGENT:
    - Use this tool ONLY when the user asks for intensity/heatmap over coordinates or geospatial hotspots.
    - The output is an HTML file (NOT a PNG). Return the HTML path to the user.
    - Ensure the input table has valid lat/lng columns (common aliases are handled).
    """
    try:
        if folium is None or HeatMap is None:
            return json.dumps({"error": "folium not available in this environment."})

        df = _load_dataframe(data_path)
        df = _apply_filters(df, filters)

        # resolve column aliases
        lat_candidates = [lat_col, "lat", "latitude", "LAT", "Latitude"]
        lng_candidates = [lng_col, "lng", "lon", "long", "longitude", "LNG", "LON", "Longitude"]
        lat_name = next((c for c in lat_candidates if c in df.columns), None)
        lng_name = next((c for c in lng_candidates if c in df.columns), None)
        if not lat_name or not lng_name:
            return json.dumps({"error": f"Lat/Lng columns not found. Tried {lat_candidates} / {lng_candidates}."})

        use_weight = bool(weight_col) and weight_col in df.columns
        coords = df[[lat_name, lng_name] + ([weight_col] if use_weight else [])].copy()
        # CSV/JSON round-trips can leave coordinates as text; coerce so range checks are numeric.
        coords[lat_name] = pd.to_numeric(coords[lat_name], errors="coerce")
        coords[lng_name] = pd.to_numeric(coords[lng_name], errors="coerce")
        coords = coords.dropna(subset=[lat_name, lng_name])
        # remove clearly invalid coordinates
        coords = coords[coords[lat_name].between(-90, 90) & coords[lng_name].between(-180, 180)]
        if coords.empty:
            return json.dumps({"error": "No valid coordinates found to plot."})

        # center / bounds
        min_lat, max_lat = coords[lat_name].min(), coords[lat_name].max()
        min_lng, max_lng = coords[lng_name].min(), coords[lng_name].max()
        center = [(min_lat + max_lat) / 2.0, (min_lng + max_lng) / 2.0]

        m = folium.Map(location=center, tiles=tiles, zoom_start=zoom_start or 5, control_scale=True)

        # Prepare data for HeatMap as [lat, lng] or [lat, lng, weight] rows
        if use_weight:
            # Leaflet.heat's default max intensity is 1.0, so raw counts (> 1) would all render
            # at full intensity. Rescale to [0, 1] to keep relative differences visible.
            weights = pd.to_numeric(coords[weight_col], errors="coerce").fillna(0.0).clip(lower=0.0)
            peak = float(weights.max())
            if peak > 0:
                weights = weights / peak
            coords = coords.assign(**{weight_col: weights})
            heat_data = coords[[lat_name, lng_name, weight_col]].astype(float).values.tolist()
        else:
            heat_data = coords[[lat_name, lng_name]].astype(float).values.tolist()

        HeatMap(
            heat_data,
            radius=radius,
            blur=blur,
            min_opacity=min_opacity,
            max_zoom=max_zoom,
            gradient=gradient
        ).add_to(m)

        # Auto-fit only when the caller did not ask for a specific initial zoom.
        if zoom_start is None:
            try:
                m.fit_bounds([[min_lat, min_lng], [max_lat, max_lng]])
            except Exception:
                pass  # best-effort: the map is still usable at the default zoom

        html_path = output_filename or os.path.join(OUTPUTS_DIR, f"heatmap_{_now_ts()}.html")
        if not os.path.dirname(html_path):
            html_path = os.path.join(OUTPUTS_DIR, html_path)
        os.makedirs(os.path.dirname(html_path), exist_ok=True)
        m.save(html_path)

        return json.dumps({
            "map_html": html_path,
            "points_plotted": int(len(heat_data)),
            "bounds": [[float(min_lat), float(min_lng)], [float(max_lat), float(max_lng)]],
            "weight_column_used": weight_col if use_weight else None,
            "notes": "Open the HTML file in a browser to view the interactive heatmap."
        })
    except Exception as e:
        return json.dumps({"error": f"map_heatmap failed: {e}"})


In [None]:
# -- 3) PROMPT + AGENT ---------------------------------------------------------
# Optional: give the agent quick access to schema overview and table info.
db = SQLDatabase.from_uri(SQLITE_URI, sample_rows_in_table_info=2)
list_tables_tool = ListSQLDatabaseTool(db=db)
info_table_tool  = InfoSQLDatabaseTool(db=db)

# NOTE: SYSTEM_RULES is used as a prompt *template*, so it must not contain { or } braces.
SYSTEM_RULES = textwrap.dedent("""
You are a careful data analyst working with a local SQLite database.
You MUST follow these rules:

- SAFETY: Only run read-only SQL (SELECT or PRAGMA table_info). No writes or DDL.
- EFFICIENCY: Avoid huge pulls. If the user didn't specify limits/time windows, start small.
- CHOOSE METHOD: Decide the simplest correct analysis (groupby aggregate, describe, or basic stats).
- VISUALIZATION:
  - Use `plot_visualization` for standard charts (line, area, bar, scatter, histogram, box, violin, pie).
  - Use `map_heatmap` for geospatial heatmaps over lat/lng coordinates.
  - Only create visuals if they add value to the user's request, and report the saved file path.
- PIPELINE:
  1) If any metrics need to be calculated, state the formula for each metric first.
  2) For those metrics only, try to find online references for their definitions and cite them.
  3) (Optional) Inspect the schema with list_tables/info if unsure.
  4) Write a SELECT query with the minimal set of columns and filters needed.
  5) Call `run_sqlite_query` to execute and save results.
  6) If aggregation is needed, call `pandas_aggregate` with a structured spec.
  7) Create the output following the OUTPUT FORMAT section below exactly.
  8) If visualization is helpful, call `plot_visualization` for charts or `map_heatmap` for geospatial heatmaps.
- OUTPUT FORMAT:
  - PLAN SECTION: Start with a short PLAN (bullets). It should outline the assumptions you have made, as well as the definition of any metrics you've created.
  - SQL CODE SECTION: Show the SQL used (fenced code). It should be clear which SQL queries you used to reach the results.
  - RESULTS SECTION: Show your final result as a table.
  - SUMMARY SECTION: End with a crisp, non-verbose paragraph describing the data to the user.
""").strip()

# The prompt ensures the agent builds a simple plan + uses tools in-order
PROMPT = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_RULES),
    ("human", "User analysis request: {analysis_request}\n\nProceed."),
    MessagesPlaceholder("agent_scratchpad"),
])

# Expose all six tools to the agent: schema helpers, SQL, pandas, charts and the map heatmap.
tools = [list_tables_tool, info_table_tool, run_sqlite_query, pandas_aggregate, plot_visualization, map_heatmap]

# Create a tool-calling agent with our prompt
agent_runnable = create_tool_calling_agent(llm, tools, PROMPT)

# Wrap it in an executor so we can capture intermediate steps (the “action trace”)
executor = AgentExecutor(
    agent=agent_runnable,
    tools=tools,
    verbose=True,                     # prints a readable trace in the cell output
    return_intermediate_steps=True,   # lets us programmatically inspect tool calls
    max_iterations=16,
    handle_parsing_errors=True,
)
print("Agent ready.")


In [None]:
# -- 4) RUN THE AGENT ----------------------------------------------------------
# Try your own request here. A few examples:
# "What is the average utilization rate of the bikes for each station in Palo Alto? Meaning that how much each station's bikes are utilized by time?"
# "What are the top 10 stations by departures on weekdays 7–10 AM?"
# "Which routes between stations have the highest subscriber share compared to customers?"
# "What share of rides are done by commuters (7-10 AM weekdays) vs non-commuters?"
# "Create a heatmap of areas with the highest share of commuters (7-10 AM weekdays) starting their rides there. Meaning the areas that have the highest share of rides starting there in the morning commute time."
# "Generate a line graph comparing the utilization rate (number of daily trips per bike capacity of each station) of stations over time between the cities during September of 2013."

analysis_request = "Create a heatmap of areas with the highest share of commuters (7-10 AM weekdays) starting their rides there. Meaning the areas that have the highest share of rides starting there in the morning commute time."

result = executor.invoke({"analysis_request": analysis_request})

# The AgentExecutor returns:
# - 'output' (final text answer)
# - 'intermediate_steps' ([(AgentAction, observation), ...]) = our action trace
print("\n=== FINAL ANSWER ===")
print(result["output"])

MAX_OBSERVATION_CHARS = 800  # observations can be large JSON strings; truncate for readability

print("\n=== ACTION TRACE (tools & observations) ===")
for i, (action, observation) in enumerate(result["intermediate_steps"], 1):
    print(f"\n-- Step {i}:")
    print("Tool:", getattr(action, "tool", "<unknown>"))
    print("Tool Input:", getattr(action, "tool_input", {}))
    obs_str = str(observation)
    if len(obs_str) > MAX_OBSERVATION_CHARS:
        obs_str = obs_str[:MAX_OBSERVATION_CHARS] + "..."
    print("Observation:", obs_str)
# The final answer was printed once above; it is not repeated after every step.


In [None]:
# Render the agent's final answer (markdown tables, fenced SQL) as formatted cell output.
display(Markdown(data=result["output"]))